Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
79436248
"...docs/git@developer.sourcefind.cn:dcuai/dlexamples.git" did not exist on "76ccaa54e9a1aa224ffac27787498f7fab451bb6"
Commit
79436248
authored
Jun 03, 2021
by
WenmuZhou
Browse files
add table eval and predict script
parent
ad4853db
Changes
12
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
4060 additions
and
4 deletions
+4060
-4
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+172
-3
ppocr/utils/dict/table_dict.txt
ppocr/utils/dict/table_dict.txt
+278
-0
ppocr/utils/dict/table_structure_dict.txt
ppocr/utils/dict/table_structure_dict.txt
+2759
-0
ppocr/utils/table_utils/matcher.py
ppocr/utils/table_utils/matcher.py
+214
-0
ppstructure/predict_system.py
ppstructure/predict_system.py
+123
-0
ppstructure/table/__init__.py
ppstructure/table/__init__.py
+13
-0
ppstructure/table/eval_table.py
ppstructure/table/eval_table.py
+67
-0
ppstructure/table/predict_structure.py
ppstructure/table/predict_structure.py
+141
-0
ppstructure/table/predict_table.py
ppstructure/table/predict_table.py
+222
-0
ppstructure/table/table_metric/__init__.py
ppstructure/table/table_metric/__init__.py
+16
-0
ppstructure/table/table_metric/parallel.py
ppstructure/table/table_metric/parallel.py
+51
-0
tools/infer/utility.py
tools/infer/utility.py
+4
-1
No files found.
ppocr/postprocess/rec_postprocess.py
View file @
79436248
...
@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
...
@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self
.
character_str
=
string
.
printable
[:
-
6
]
self
.
character_str
=
string
.
printable
[:
-
6
]
dict_character
=
list
(
self
.
character_str
)
dict_character
=
list
(
self
.
character_str
)
elif
character_type
in
support_character_type
:
elif
character_type
in
support_character_type
:
self
.
character_str
=
""
self
.
character_str
=
[]
assert
character_dict_path
is
not
None
,
"character_dict_path should not be None when character_type is {}"
.
format
(
assert
character_dict_path
is
not
None
,
"character_dict_path should not be None when character_type is {}"
.
format
(
character_type
)
character_type
)
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
lines
=
fin
.
readlines
()
for
line
in
lines
:
for
line
in
lines
:
line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
self
.
character_str
+=
line
self
.
character_str
.
append
(
line
)
if
use_space_char
:
if
use_space_char
:
self
.
character_str
+=
" "
self
.
character_str
.
append
(
" "
)
dict_character
=
list
(
self
.
character_str
)
dict_character
=
list
(
self
.
character_str
)
else
:
else
:
...
@@ -288,3 +288,172 @@ class SRNLabelDecode(BaseRecLabelDecode):
...
@@ -288,3 +288,172 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
%
beg_or_end
return
idx
return
idx
class
TableLabelDecode
(
object
):
""" """
def
__init__
(
self
,
max_text_length
,
max_elem_length
,
max_cell_num
,
character_dict_path
,
**
kwargs
):
self
.
max_text_length
=
max_text_length
self
.
max_elem_length
=
max_elem_length
self
.
max_cell_num
=
max_cell_num
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
self
.
dict_character
=
{}
self
.
dict_idx_character
=
{}
for
i
,
char
in
enumerate
(
list_character
):
self
.
dict_idx_character
[
i
]
=
char
self
.
dict_character
[
char
]
=
i
self
.
dict_elem
=
{}
self
.
dict_idx_elem
=
{}
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_idx_elem
[
i
]
=
elem
self
.
dict_elem
[
elem
]
=
i
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_elem
=
[]
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
get_sp_tokens
(
self
):
char_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'char'
)
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
sp_tokens
=
np
.
array
([
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
elem_end_idx
,
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
self
.
max_elem_length
,
self
.
max_cell_num
])
return
sp_tokens
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
if
isinstance
(
structure_probs
,
paddle
.
Tensor
):
structure_probs
=
structure_probs
.
numpy
()
if
isinstance
(
loc_preds
,
paddle
.
Tensor
):
loc_preds
=
loc_preds
.
numpy
()
structure_idx
=
structure_probs
.
argmax
(
axis
=
2
)
structure_probs
=
structure_probs
.
max
(
axis
=
2
)
structure_str
,
structure_pos
,
result_score_list
,
result_elem_idx_list
=
self
.
decode
(
structure_idx
,
structure_probs
,
'elem'
)
res_html_code_list
=
[]
res_loc_list
=
[]
batch_num
=
len
(
structure_str
)
for
bno
in
range
(
batch_num
):
res_loc
=
[]
for
sno
in
range
(
len
(
structure_str
[
bno
])):
text
=
structure_str
[
bno
][
sno
]
if
text
in
[
'<td>'
,
'<td'
]:
pos
=
structure_pos
[
bno
][
sno
]
res_loc
.
append
(
loc_preds
[
bno
,
pos
])
res_html_code
=
''
.
join
(
structure_str
[
bno
])
res_loc
=
np
.
array
(
res_loc
)
res_html_code_list
.
append
(
res_html_code
)
res_loc_list
.
append
(
res_loc
)
return
{
'res_html_code'
:
res_html_code_list
,
'res_loc'
:
res_loc_list
,
'res_score_list'
:
result_score_list
,
'res_elem_idx_list'
:
result_elem_idx_list
,
'structure_str_list'
:
structure_str
}
def
decode
(
self
,
text_index
,
structure_probs
,
char_or_elem
):
"""convert text-label into text-index.
"""
if
char_or_elem
==
"char"
:
max_len
=
self
.
max_text_length
current_dict
=
self
.
dict_idx_character
else
:
max_len
=
self
.
max_elem_length
current_dict
=
self
.
dict_idx_elem
ignored_tokens
=
self
.
get_ignored_tokens
(
'elem'
)
beg_idx
,
end_idx
=
ignored_tokens
# select_td_tokens = []
# select_span_tokens = []
# for elem in self.dict_elem:
# # if elem == '<td>' or elem == '<td' or elem == '<tr>'\
# # or 'rowspan' in elem or 'colspan' in elem:
# if elem == '<td>' or elem == '<td' or elem == '<tr>':
# select_td_tokens.append(self.dict_elem[elem])
# if 'rowspan' in elem or 'colspan' in elem:
# select_span_tokens.append(self.dict_elem[elem])
result_list
=
[]
result_pos_list
=
[]
result_score_list
=
[]
result_elem_idx_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
elem_pos_list
=
[]
elem_idx_list
=
[]
score_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
tmp_elem_idx
=
int
(
text_index
[
batch_idx
][
idx
])
if
idx
>
0
and
tmp_elem_idx
==
end_idx
:
break
if
tmp_elem_idx
in
ignored_tokens
:
continue
# if tmp_elem_idx in select_td_tokens:
# total_td_score += structure_probs[batch_idx, idx]
# total_td_num += 1
# if tmp_elem_idx in select_span_tokens:
# total_span_score += structure_probs[batch_idx, idx]
# total_span_num += 1
char_list
.
append
(
current_dict
[
tmp_elem_idx
])
elem_pos_list
.
append
(
idx
)
score_list
.
append
(
structure_probs
[
batch_idx
,
idx
])
elem_idx_list
.
append
(
tmp_elem_idx
)
result_list
.
append
(
char_list
)
result_pos_list
.
append
(
elem_pos_list
)
result_score_list
.
append
(
score_list
)
result_elem_idx_list
.
append
(
elem_idx_list
)
return
result_list
,
result_pos_list
,
result_score_list
,
result_elem_idx_list
def
get_ignored_tokens
(
self
,
char_or_elem
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
,
char_or_elem
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
,
char_or_elem
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
,
char_or_elem
):
if
char_or_elem
==
"char"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_character
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_character
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of char"
\
%
beg_or_end
elif
char_or_elem
==
"elem"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_elem
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_elem
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
%
beg_or_end
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
%
char_or_elem
return
idx
ppocr/utils/dict/table_dict.txt
0 → 100644
View file @
79436248
←
</overline>
☆
─
α
⋅
$
ω
ψ
χ
(
υ
≥
σ
,
ρ
ε
0
■
4
8
✗
b
<
✓
Ψ
Ω
€
D
3
Π
H
║
</
>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
√
t
</sub>
x
Β
Γ
Δ
|
ǂ
ɛ
j
̧
➢
̌
′
«
△
▲
#
</b>
'
Ι
+
¶
/
▼
⇑
□
·
7
▪
;
?
➔
∩
C
÷
G
⇒
K
<sup>
O
S
С
W
Α
[
○
_
●
‡
c
z
g
<i>
o
<sub>
〈
〉
s
⩽
w
φ
ʹ
{
»
∣
̆
e
ˆ
∈
τ
◆
ι
∅
∆
∙
∘
Ø
ß
✔
∞
∑
−
×
◊
∗
∖
˃
˂
∫
"
i
&
π
↔
*
∥
æ
∧
.
⁄
ø
Q
∼
6
⁎
:
★
>
a
B
≈
F
J
̄
N
♯
R
V
<overline>
―
Z
♣
^
¤
¥
§
<underline>
¢
£
≦
≤
‖
Λ
©
n
↓
→
↑
r
°
±
v
<b>
♂
k
♀
~
ᅟ
̇
@
”
♦
ł
®
⊕
„
!
</sup>
%
⇓
)
-
1
5
9
=
А
A
‰
⋆
Σ
E
◦
I
※
M
m
̨
⩾
†
</i>
•
U
Y
]
̸
2
‐
–
‒
̂
—
̀
́
’
‘
⋮
⋯
̊
“
̈
≧
q
u
ı
y
</underline>
̃
}
ν
ppocr/utils/dict/table_structure_dict.txt
0 → 100644
View file @
79436248
This diff is collapsed.
Click to expand it.
ppocr/utils/table_utils/matcher.py
0 → 100755
View file @
79436248
import
json
def
distance
(
box_1
,
box_2
):
x1
,
y1
,
x2
,
y2
=
box_1
x3
,
y3
,
x4
,
y4
=
box_2
# min_x = (x1 + x2) / 2
# min_y = (y1 + y2) / 2
# max_x = (x3 + x4) / 2
# max_y = (y3 + y4) / 2
dis
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
+
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
dis_2
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
dis_3
=
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
#dis = pow(min_x - max_x, 2) + pow(min_y - max_y, 2) + pow(x3 - x1, 2) + pow(y3 - y1, 2) + pow(x4- x2, 2) + pow(y4 - y2, 2) + abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
return
dis
+
min
(
dis_2
,
dis_3
)
def
compute_iou
(
rec1
,
rec2
):
"""
computing IoU
:param rec1: (y0, x0, y1, x1), which reflects
(top, left, bottom, right)
:param rec2: (y0, x0, y1, x1)
:return: scala value of IoU
"""
# computing area of each rectangles
rec1
,
rec2
=
rec1
*
1000
,
rec2
*
1000
S_rec1
=
(
rec1
[
2
]
-
rec1
[
0
])
*
(
rec1
[
3
]
-
rec1
[
1
])
S_rec2
=
(
rec2
[
2
]
-
rec2
[
0
])
*
(
rec2
[
3
]
-
rec2
[
1
])
# computing the sum_area
sum_area
=
S_rec1
+
S_rec2
# find the each edge of intersect rectangle
left_line
=
max
(
rec1
[
1
],
rec2
[
1
])
right_line
=
min
(
rec1
[
3
],
rec2
[
3
])
top_line
=
max
(
rec1
[
0
],
rec2
[
0
])
bottom_line
=
min
(
rec1
[
2
],
rec2
[
2
])
# judge if there is an intersect
if
left_line
>=
right_line
or
top_line
>=
bottom_line
:
return
0
else
:
intersect
=
(
right_line
-
left_line
)
*
(
bottom_line
-
top_line
)
return
(
intersect
/
(
sum_area
-
intersect
))
*
1.0
def
matcher_merge
(
ocr_bboxes
,
pred_bboxes
):
# ocr_bboxes: OCR pred_bboxes:端到端
all_dis
=
[]
ious
=
[]
matched
=
{}
for
i
,
gt_box
in
enumerate
(
ocr_bboxes
):
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
((
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
#获取两两cell之间的L1距离和 1- IOU
sorted_distances
=
distances
.
copy
()
# 根据距离和IOU挑选最"近"的cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
def
complex_num
(
pred_bboxes
):
complex_nums
=
[]
for
bbox
in
pred_bboxes
:
distances
=
[]
temp_ious
=
[]
for
pred_bbox
in
pred_bboxes
:
if
bbox
!=
pred_bbox
:
distances
.
append
(
distance
(
bbox
,
pred_bbox
))
temp_ious
.
append
(
compute_iou
(
bbox
,
pred_bbox
))
complex_nums
.
append
(
temp_ious
[
distances
.
index
(
min
(
distances
))])
return
sum
(
complex_nums
)
/
len
(
complex_nums
)
def
get_rows
(
pred_bboxes
):
pre_bbox
=
pred_bboxes
[
0
]
res
=
[]
step
=
0
for
i
in
range
(
len
(
pred_bboxes
)):
bbox
=
pred_bboxes
[
i
]
if
bbox
[
1
]
-
pre_bbox
[
1
]
>
2
or
bbox
[
0
]
-
pre_bbox
[
0
]
<
0
:
break
else
:
res
.
append
(
bbox
)
step
+=
1
for
i
in
range
(
step
):
pred_bboxes
.
pop
(
0
)
return
res
,
pred_bboxes
def
refine_rows
(
pred_bboxes
):
# 微调整行的框,使在一条水平线上
ys_1
=
[]
ys_2
=
[]
for
box
in
pred_bboxes
:
ys_1
.
append
(
box
[
1
])
ys_2
.
append
(
box
[
3
])
min_y_1
=
sum
(
ys_1
)
/
len
(
ys_1
)
min_y_2
=
sum
(
ys_2
)
/
len
(
ys_2
)
re_boxes
=
[]
for
box
in
pred_bboxes
:
box
[
1
]
=
min_y_1
box
[
3
]
=
min_y_2
re_boxes
.
append
(
box
)
return
re_boxes
def
matcher_refine_row
(
gt_bboxes
,
pred_bboxes
):
before_refine_pred_bboxes
=
pred_bboxes
.
copy
()
pred_bboxes
=
[]
while
(
len
(
before_refine_pred_bboxes
)
!=
0
):
row_bboxes
,
before_refine_pred_bboxes
=
get_rows
(
before_refine_pred_bboxes
)
print
(
row_bboxes
)
pred_bboxes
.
extend
(
refine_rows
(
row_bboxes
))
all_dis
=
[]
ious
=
[]
matched
=
{}
for
i
,
gt_box
in
enumerate
(
gt_bboxes
):
distances
=
[]
#temp_ious = []
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if
distances
.
index
(
min
(
distances
))
not
in
matched
.
keys
():
matched
[
distances
.
index
(
min
(
distances
))]
=
[
i
]
else
:
matched
[
distances
.
index
(
min
(
distances
))].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
def
matcher_structure_1
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
gt_box_index
=
0
delete_gt_bboxes
=
gt_bboxes
.
copy
()
match_bboxes_ready
=
[]
matched
=
{}
while
(
len
(
delete_gt_bboxes
)
!=
0
):
row_bboxes
,
delete_gt_bboxes
=
get_rows
(
delete_gt_bboxes
)
row_bboxes
=
sorted
(
row_bboxes
,
key
=
lambda
key
:
key
[
0
])
if
len
(
pred_bboxes_rows
)
>
0
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
print
(
row_bboxes
)
for
i
,
gt_box
in
enumerate
(
row_bboxes
):
#print(gt_box)
pred_distances
=
[]
distances
=
[]
for
pred_bbox
in
pred_bboxes
:
pred_distances
.
append
(
distance
(
gt_box
,
pred_bbox
))
for
j
,
pred_box
in
enumerate
(
match_bboxes_ready
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
index
=
pred_distances
.
index
(
min
(
distances
))
#print('index', index)
if
index
not
in
matched
.
keys
():
matched
[
index
]
=
[
gt_box_index
]
else
:
matched
[
index
].
append
(
gt_box_index
)
gt_box_index
+=
1
return
matched
def
matcher_structure
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
'''
gt_bboxes: 排序后
pred_bboxes:
'''
pre_bbox
=
gt_bboxes
[
0
]
matched
=
{}
match_bboxes_ready
=
[]
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
for
i
,
gt_box
in
enumerate
(
gt_bboxes
):
pred_distances
=
[]
for
pred_bbox
in
pred_bboxes
:
pred_distances
.
append
(
distance
(
gt_box
,
pred_bbox
))
distances
=
[]
gap_pre
=
gt_box
[
1
]
-
pre_bbox
[
1
]
gap_pre_1
=
gt_box
[
0
]
-
pre_bbox
[
2
]
#print(gap_pre, len(pred_bboxes_rows))
if
(
gap_pre_1
<
0
and
len
(
pred_bboxes_rows
)
>
0
):
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
pred_bboxes_rows
)
==
1
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
match_bboxes_ready
)
==
0
and
len
(
pred_bboxes_rows
)
>
0
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
match_bboxes_ready
)
==
0
and
len
(
pred_bboxes_rows
)
==
0
:
break
#print(match_bboxes_ready)
for
j
,
pred_box
in
enumerate
(
match_bboxes_ready
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
index
=
pred_distances
.
index
(
min
(
distances
))
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print
(
gt_box
,
match_bboxes_ready
[
distances
.
index
(
min
(
distances
))])
if
index
not
in
matched
.
keys
():
matched
[
index
]
=
[
i
]
else
:
matched
[
index
].
append
(
i
)
pre_bbox
=
gt_box
return
matched
def
main
():
detect_bboxes
=
json
.
load
(
open
(
'./f_detecion_bbox.json'
))
gt_bboxes
=
json
.
load
(
open
(
'./f_gt_bbox.json'
))
all_node
=
0
matched_right
=
0
key
=
'PMC4796501_003_00.png'
print
(
key
)
gt_bbox
=
gt_bboxes
[
key
]
pred_bbox
=
detect_bboxes
[
key
]
matched
=
matcher
(
gt_bbox
,
pred_bbox
)
print
(
matched
)
if
__name__
==
"__main__"
:
main
()
ppstructure/predict_system.py
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
copy
import
numpy
as
np
import
time
import
tools.infer.utility
as
utility
from
tools.infer.predict_system
import
TextSystem
from
ppstructure.table.predict_table
import
TableSystem
,
to_excel
from
ppstructure.layout.predict_layout
import
LayoutDetector
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
logger
=
get_logger
()
def
parse_args
():
parser
=
utility
.
init_args
()
# params for table structure
parser
.
add_argument
(
"--table_max_len"
,
type
=
int
,
default
=
488
)
parser
.
add_argument
(
"--table_max_text_length"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--table_max_elem_length"
,
type
=
int
,
default
=
800
)
parser
.
add_argument
(
"--table_max_cell_num"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--table_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--table_char_type"
,
type
=
str
,
default
=
'en'
)
parser
.
add_argument
(
"--table_char_dict_path"
,
type
=
str
,
default
=
"./ppocr/utils/dict/table_structure_dict.txt"
)
# params for layout detector
parser
.
add_argument
(
"--layout_model_dir"
,
type
=
str
)
return
parser
.
parse_args
()
class
OCRSystem
():
def
__init__
(
self
,
args
):
self
.
text_system
=
TextSystem
(
args
)
self
.
table_system
=
TableSystem
(
args
)
self
.
table_layout
=
LayoutDetector
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
drop_score
=
args
.
drop_score
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
layout_res
=
self
.
table_layout
(
copy
.
deepcopy
(
img
))
for
region
in
layout_res
:
x1
,
y1
,
x2
,
y2
=
region
[
'bbox'
]
roi_img
=
ori_im
[
y1
:
y2
,
x1
:
x2
,:]
if
region
[
'label'
]
==
'table'
:
res
=
self
.
table_system
(
roi_img
)
else
:
res
=
self
.
text_system
(
roi_img
)
region
[
'res'
]
=
res
return
layout_res
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
excel_save_folder
=
'output/table'
os
.
makedirs
(
excel_save_folder
,
exist_ok
=
True
)
text_sys
=
OCRSystem
(
args
)
img_num
=
len
(
image_file_list
)
for
i
,
image_file
in
enumerate
(
image_file_list
):
logger
.
info
(
"[{}/{}] {}"
.
format
(
i
,
img_num
,
image_file
))
img
,
flag
=
check_and_read_gif
(
image_file
)
imgname
=
os
.
path
.
basename
(
image_file
).
split
(
'.'
)[
0
]
# excel_path = os.path.join(excel_save_folder, + '.xlsx')
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
res
=
text_sys
(
img
)
for
region
in
res
:
if
region
[
'label'
]
==
'table'
:
# x1, y1, x2, y2 = region['bbox']
excel_path
=
os
.
path
.
join
(
excel_save_folder
,
'{}_{}.xlsx'
.
format
(
imgname
,
region
[
'bbox'
]))
to_excel
(
region
[
'res'
],
excel_path
)
logger
.
info
(
res
)
elapse
=
time
.
time
()
-
starttime
logger
.
info
(
"Predict time : {:.3f}s"
.
format
(
elapse
))
if
__name__
==
"__main__"
:
args
=
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
ppstructure/table/__init__.py
0 → 100644
View file @
79436248
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ppstructure/table/eval_table.py
0 → 100755
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
cv2
import
json
from
tqdm
import
tqdm
from
ppstructure.table.table_metric
import
TEDS
from
ppstructure.table.predict_table
import
TableSystem
,
utility
def
main
(
gt_path
,
img_root
,
args
):
teds
=
TEDS
(
n_jobs
=
16
)
text_sys
=
TableSystem
(
args
)
jsons_gt
=
json
.
load
(
open
(
gt_path
))
# gt
pred_htmls
=
[]
gt_htmls
=
[]
for
img_name
in
tqdm
(
jsons_gt
):
if
img_name
!=
'PMC1064865_002_00.png'
:
continue
# 读取信息
img
=
cv2
.
imread
(
os
.
path
.
join
(
img_root
,
img_name
))
pred_html
=
text_sys
(
img
)
pred_htmls
.
append
(
pred_html
)
gt_structures
,
gt_bboxes
,
gt_contents
,
contents_with_block
=
jsons_gt
[
img_name
]
gt_html
,
gt
=
get_gt_html
(
gt_structures
,
contents_with_block
)
# 获取HTMLgt
gt_htmls
.
append
(
gt_html
)
scores
=
teds
.
batch_evaluate_html
(
gt_htmls
,
pred_htmls
)
# 计算teds
print
(
'teds:'
,
sum
(
scores
)
/
len
(
scores
))
def
get_gt_html
(
gt_structures
,
contents_with_block
):
end_html
=
[]
td_index
=
0
for
tag
in
gt_structures
:
if
'</td>'
in
tag
:
if
contents_with_block
[
td_index
]
!=
[]:
end_html
.
extend
(
contents_with_block
[
td_index
])
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
if
__name__
==
'__main__'
:
args
=
utility
.
parse_args
()
gt_path
=
'table/match_code/f_gt_bbox.json'
img_root
=
'table/imgs'
main
(
gt_path
,
img_root
,
args
)
ppstructure/table/predict_structure.py
0 → 100755
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
numpy
as
np
import
math
import
time
import
traceback
import
paddle
import
tools.infer.utility
as
utility
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
logger
=
get_logger
()
class
TableStructurer
(
object
):
def
__init__
(
self
,
args
):
pre_process_list
=
[{
'ResizeTableImage'
:
{
'max_len'
:
args
.
table_max_len
}
},
{
'NormalizeImage'
:
{
'std'
:
[
0.229
,
0.224
,
0.225
],
'mean'
:
[
0.485
,
0.456
,
0.406
],
'scale'
:
'1./255.'
,
'order'
:
'hwc'
}
},
{
'PaddingTableImage'
:
None
},
{
'ToCHWImage'
:
None
},
{
'KeepKeys'
:
{
'keep_keys'
:
[
'image'
]
}
}]
postprocess_params
=
{
'name'
:
'TableLabelDecode'
,
"character_type"
:
args
.
table_char_type
,
"character_dict_path"
:
args
.
table_char_dict_path
,
"max_text_length"
:
args
.
table_max_text_length
,
"max_elem_length"
:
args
.
table_max_elem_length
,
"max_cell_num"
:
args
.
table_max_cell_num
}
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
'table'
,
logger
)
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
transform
(
data
,
self
.
preprocess_op
)
img
=
data
[
0
]
if
img
is
None
:
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
{}
preds
[
'structure_probs'
]
=
outputs
[
1
]
preds
[
'loc_preds'
]
=
outputs
[
0
]
post_result
=
self
.
postprocess_op
(
preds
)
structure_str_list
=
post_result
[
'structure_str_list'
]
res_loc
=
post_result
[
'res_loc'
]
imgh
,
imgw
=
ori_im
.
shape
[
0
:
2
]
res_loc_final
=
[]
for
rno
in
range
(
len
(
res_loc
[
0
])):
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
structure_str_list
=
structure_str_list
[
0
][:
-
1
]
structure_str_list
=
[
'<html>'
,
'<body>'
,
'<table>'
]
+
structure_str_list
+
[
'</table>'
,
'</body>'
,
'</html>'
]
elapse
=
time
.
time
()
-
starttime
return
(
structure_str_list
,
res_loc_final
),
elapse
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
table_structurer
=
TableStructurer
(
args
)
count
=
0
total_time
=
0
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
structure_res
,
elapse
=
table_structurer
(
img
)
logger
.
info
(
"result: {}"
.
format
(
structure_res
))
if
count
>
0
:
total_time
+=
elapse
count
+=
1
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
ppstructure/table/predict_table.py
0 → 100644
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
copy
import
numpy
as
np
import
time
import
tools.infer.utility
as
utility
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_det
as
predict_det
import
ppstructure.table.predict_structure
as
predict_strture
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.table_utils.matcher
import
distance
,
compute_iou
logger
=
get_logger
()
def
expand
(
pix
,
det_box
,
shape
):
x0
,
y0
,
x1
,
y1
=
det_box
# print(shape)
h
,
w
,
c
=
shape
tmp_x0
=
x0
-
pix
tmp_x1
=
x1
+
pix
tmp_y0
=
y0
-
pix
tmp_y1
=
y1
+
pix
x0_
=
tmp_x0
if
tmp_x0
>=
0
else
0
x1_
=
tmp_x1
if
tmp_x1
<=
w
else
w
y0_
=
tmp_y0
if
tmp_y0
>=
0
else
0
y1_
=
tmp_y1
if
tmp_y1
<=
h
else
h
return
x0_
,
y0_
,
x1_
,
y1_
class
TableSystem
(
object
):
def
__init__
(
self
,
args
):
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
table_structurer
=
predict_strture
.
TableStructurer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
drop_score
=
args
.
drop_score
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
structure_res
,
elapse
=
self
.
table_structurer
(
copy
.
deepcopy
(
img
))
dt_boxes
,
elapse
=
self
.
text_detector
(
copy
.
deepcopy
(
img
))
dt_boxes
=
sorted_boxes
(
dt_boxes
)
r_boxes
=
[]
for
box
in
dt_boxes
:
x_min
=
box
[:,
0
].
min
()
-
1
x_max
=
box
[:,
0
].
max
()
+
1
y_min
=
box
[:,
1
].
min
()
-
1
y_max
=
box
[:,
1
].
max
()
+
1
box
=
[
x_min
,
y_min
,
x_max
,
y_max
]
r_boxes
.
append
(
box
)
dt_boxes
=
np
.
array
(
r_boxes
)
# logger.info("dt_boxes num : {}, elapse : {}".format(
# len(dt_boxes), elapse))
if
dt_boxes
is
None
:
return
None
,
None
img_crop_list
=
[]
for
i
in
range
(
len
(
dt_boxes
)):
det_box
=
dt_boxes
[
i
]
x0
,
y0
,
x1
,
y1
=
expand
(
2
,
det_box
,
ori_im
.
shape
)
text_rect
=
ori_im
[
int
(
y0
):
int
(
y1
),
int
(
x0
):
int
(
x1
),
:]
img_crop_list
.
append
(
text_rect
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
# logger.info("rec_res num : {}, elapse : {}".format(
# len(rec_res), elapse))
pred_html
,
pred
=
self
.
rebuild_table
(
structure_res
,
dt_boxes
,
rec_res
)
return
pred_html
def
rebuild_table
(
self
,
structure_res
,
dt_boxes
,
rec_res
):
pred_structures
,
pred_bboxes
=
structure_res
matched_index
=
self
.
match_result
(
dt_boxes
,
pred_bboxes
)
pred_html
,
pred
=
self
.
get_pred_html
(
pred_structures
,
matched_index
,
rec_res
)
return
pred_html
,
pred
def
match_result
(
self
,
dt_boxes
,
pred_bboxes
):
matched
=
{}
for
i
,
gt_box
in
enumerate
(
dt_boxes
):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
(
(
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
# 获取两两cell之间的L1距离和 1- IOU
sorted_distances
=
distances
.
copy
()
# 根据距离和IOU挑选最"近"的cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
def
get_pred_html
(
self
,
pred_structures
,
matched_index
,
ocr_contents
):
end_html
=
[]
td_index
=
0
for
tag
in
pred_structures
:
if
'</td>'
in
tag
:
if
td_index
in
matched_index
.
keys
():
b_with
=
False
if
'<b>'
in
ocr_contents
[
matched_index
[
td_index
][
0
]]
and
len
(
matched_index
[
td_index
])
>
1
:
b_with
=
True
end_html
.
extend
(
'<b>'
)
for
i
,
td_index_index
in
enumerate
(
matched_index
[
td_index
]):
content
=
ocr_contents
[
td_index_index
][
0
]
if
len
(
matched_index
[
td_index
])
>
1
:
if
len
(
content
)
==
0
:
continue
if
content
[
0
]
==
' '
:
content
=
content
[
1
:]
if
'<b>'
in
content
:
content
=
content
[
3
:]
if
'</b>'
in
content
:
content
=
content
[:
-
4
]
if
len
(
content
)
==
0
:
continue
if
i
!=
len
(
matched_index
[
td_index
])
-
1
and
' '
!=
content
[
-
1
]:
content
+=
' '
end_html
.
extend
(
content
)
if
b_with
:
end_html
.
extend
(
'</b>'
)
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
def
sorted_boxes
(
dt_boxes
):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes
=
dt_boxes
.
shape
[
0
]
sorted_boxes
=
sorted
(
dt_boxes
,
key
=
lambda
x
:
(
x
[
0
][
1
],
x
[
0
][
0
]))
_boxes
=
list
(
sorted_boxes
)
for
i
in
range
(
num_boxes
-
1
):
if
abs
(
_boxes
[
i
+
1
][
0
][
1
]
-
_boxes
[
i
][
0
][
1
])
<
10
and
\
(
_boxes
[
i
+
1
][
0
][
0
]
<
_boxes
[
i
][
0
][
0
]):
tmp
=
_boxes
[
i
]
_boxes
[
i
]
=
_boxes
[
i
+
1
]
_boxes
[
i
+
1
]
=
tmp
return
_boxes
def
to_excel
(
html_table
,
excel_path
):
from
tablepyxl
import
tablepyxl
tablepyxl
.
document_to_xl
(
html_table
,
excel_path
)
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
excel_save_folder
=
'output/table'
os
.
makedirs
(
excel_save_folder
,
exist_ok
=
True
)
text_sys
=
TableSystem
(
args
)
img_num
=
len
(
image_file_list
)
for
i
,
image_file
in
enumerate
(
image_file_list
):
logger
.
info
(
"[{}/{}] {}"
.
format
(
i
,
img_num
,
image_file
))
img
,
flag
=
check_and_read_gif
(
image_file
)
excel_path
=
os
.
path
.
join
(
excel_save_folder
,
os
.
path
.
basename
(
image_file
).
split
(
'.'
)[
0
]
+
'.xlsx'
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
pred_html
=
text_sys
(
img
)
to_excel
(
pred_html
,
excel_path
)
logger
.
info
(
'excel saved to {}'
.
format
(
excel_path
))
logger
.
info
(
pred_html
)
elapse
=
time
.
time
()
-
starttime
logger
.
info
(
"Predict time : {:.3f}s"
.
format
(
elapse
))
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
ppstructure/table/table_metric/__init__.py
0 → 100755
View file @
79436248
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__
=
[
'TEDS'
]
from
.table_metric
import
TEDS
\ No newline at end of file
ppstructure/table/table_metric/parallel.py
0 → 100755
View file @
79436248
from
tqdm
import
tqdm
from
concurrent.futures
import
ProcessPoolExecutor
,
as_completed
def
parallel_process
(
array
,
function
,
n_jobs
=
16
,
use_kwargs
=
False
,
front_num
=
0
):
"""
A parallel version of the map function with a progress bar.
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
n_jobs (int, default=16): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
# We run the first few iterations serially to catch bugs
if
front_num
>
0
:
front
=
[
function
(
**
a
)
if
use_kwargs
else
function
(
a
)
for
a
in
array
[:
front_num
]]
else
:
front
=
[]
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if
n_jobs
==
1
:
return
front
+
[
function
(
**
a
)
if
use_kwargs
else
function
(
a
)
for
a
in
tqdm
(
array
[
front_num
:])]
# Assemble the workers
with
ProcessPoolExecutor
(
max_workers
=
n_jobs
)
as
pool
:
# Pass the elements of array into function
if
use_kwargs
:
futures
=
[
pool
.
submit
(
function
,
**
a
)
for
a
in
array
[
front_num
:]]
else
:
futures
=
[
pool
.
submit
(
function
,
a
)
for
a
in
array
[
front_num
:]]
kwargs
=
{
'total'
:
len
(
futures
),
'unit'
:
'it'
,
'unit_scale'
:
True
,
'leave'
:
True
}
# Print out the progress as tasks complete
for
f
in
tqdm
(
as_completed
(
futures
),
**
kwargs
):
pass
out
=
[]
# Get the results from the futures.
for
i
,
future
in
tqdm
(
enumerate
(
futures
)):
try
:
out
.
append
(
future
.
result
())
except
Exception
as
e
:
out
.
append
(
e
)
return
front
+
out
tools/infer/utility.py
View file @
79436248
...
@@ -125,6 +125,8 @@ def create_predictor(args, mode, logger):
...
@@ -125,6 +125,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
cls_model_dir
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
model_dir
=
args
.
rec_model_dir
elif
mode
==
'table'
:
model_dir
=
args
.
table_model_dir
else
:
else
:
model_dir
=
args
.
e2e_model_dir
model_dir
=
args
.
e2e_model_dir
...
@@ -244,7 +246,8 @@ def create_predictor(args, mode, logger):
...
@@ -244,7 +246,8 @@ def create_predictor(args, mode, logger):
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_use_feed_fetch_ops
(
False
)
if
mode
==
'table'
:
config
.
switch_ir_optim
(
False
)
# create predictor
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
input_names
=
predictor
.
get_input_names
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment