Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
79436248
Commit
79436248
authored
Jun 03, 2021
by
WenmuZhou
Browse files
add table eval and predict script
parent
ad4853db
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
4060 additions
and
4 deletions
+4060
-4
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+172
-3
ppocr/utils/dict/table_dict.txt
ppocr/utils/dict/table_dict.txt
+278
-0
ppocr/utils/dict/table_structure_dict.txt
ppocr/utils/dict/table_structure_dict.txt
+2759
-0
ppocr/utils/table_utils/matcher.py
ppocr/utils/table_utils/matcher.py
+214
-0
ppstructure/predict_system.py
ppstructure/predict_system.py
+123
-0
ppstructure/table/__init__.py
ppstructure/table/__init__.py
+13
-0
ppstructure/table/eval_table.py
ppstructure/table/eval_table.py
+67
-0
ppstructure/table/predict_structure.py
ppstructure/table/predict_structure.py
+141
-0
ppstructure/table/predict_table.py
ppstructure/table/predict_table.py
+222
-0
ppstructure/table/table_metric/__init__.py
ppstructure/table/table_metric/__init__.py
+16
-0
ppstructure/table/table_metric/parallel.py
ppstructure/table/table_metric/parallel.py
+51
-0
tools/infer/utility.py
tools/infer/utility.py
+4
-1
No files found.
ppocr/postprocess/rec_postprocess.py
View file @
79436248
...
...
@@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
self
.
character_str
=
string
.
printable
[:
-
6
]
dict_character
=
list
(
self
.
character_str
)
elif
character_type
in
support_character_type
:
self
.
character_str
=
""
self
.
character_str
=
[]
assert
character_dict_path
is
not
None
,
"character_dict_path should not be None when character_type is {}"
.
format
(
character_type
)
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
for
line
in
lines
:
line
=
line
.
decode
(
'utf-8'
).
strip
(
"
\n
"
).
strip
(
"
\r\n
"
)
self
.
character_str
+=
line
self
.
character_str
.
append
(
line
)
if
use_space_char
:
self
.
character_str
+=
" "
self
.
character_str
.
append
(
" "
)
dict_character
=
list
(
self
.
character_str
)
else
:
...
...
@@ -288,3 +288,172 @@ class SRNLabelDecode(BaseRecLabelDecode):
assert
False
,
"unsupport type %s in get_beg_end_flag_idx"
\
%
beg_or_end
return
idx
class
TableLabelDecode
(
object
):
""" """
def
__init__
(
self
,
max_text_length
,
max_elem_length
,
max_cell_num
,
character_dict_path
,
**
kwargs
):
self
.
max_text_length
=
max_text_length
self
.
max_elem_length
=
max_elem_length
self
.
max_cell_num
=
max_cell_num
list_character
,
list_elem
=
self
.
load_char_elem_dict
(
character_dict_path
)
list_character
=
self
.
add_special_char
(
list_character
)
list_elem
=
self
.
add_special_char
(
list_elem
)
self
.
dict_character
=
{}
self
.
dict_idx_character
=
{}
for
i
,
char
in
enumerate
(
list_character
):
self
.
dict_idx_character
[
i
]
=
char
self
.
dict_character
[
char
]
=
i
self
.
dict_elem
=
{}
self
.
dict_idx_elem
=
{}
for
i
,
elem
in
enumerate
(
list_elem
):
self
.
dict_idx_elem
[
i
]
=
elem
self
.
dict_elem
[
elem
]
=
i
def
load_char_elem_dict
(
self
,
character_dict_path
):
list_character
=
[]
list_elem
=
[]
with
open
(
character_dict_path
,
"rb"
)
as
fin
:
lines
=
fin
.
readlines
()
substr
=
lines
[
0
].
decode
(
'utf-8'
).
strip
(
"
\n
"
).
split
(
"
\t
"
)
character_num
=
int
(
substr
[
0
])
elem_num
=
int
(
substr
[
1
])
for
cno
in
range
(
1
,
1
+
character_num
):
character
=
lines
[
cno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
)
list_character
.
append
(
character
)
for
eno
in
range
(
1
+
character_num
,
1
+
character_num
+
elem_num
):
elem
=
lines
[
eno
].
decode
(
'utf-8'
).
strip
(
"
\n
"
)
list_elem
.
append
(
elem
)
return
list_character
,
list_elem
def
add_special_char
(
self
,
list_character
):
self
.
beg_str
=
"sos"
self
.
end_str
=
"eos"
list_character
=
[
self
.
beg_str
]
+
list_character
+
[
self
.
end_str
]
return
list_character
def
get_sp_tokens
(
self
):
char_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'char'
)
char_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'char'
)
elem_beg_idx
=
self
.
get_beg_end_flag_idx
(
'beg'
,
'elem'
)
elem_end_idx
=
self
.
get_beg_end_flag_idx
(
'end'
,
'elem'
)
elem_char_idx1
=
self
.
dict_elem
[
'<td>'
]
elem_char_idx2
=
self
.
dict_elem
[
'<td'
]
sp_tokens
=
np
.
array
([
char_beg_idx
,
char_end_idx
,
elem_beg_idx
,
elem_end_idx
,
elem_char_idx1
,
elem_char_idx2
,
self
.
max_text_length
,
self
.
max_elem_length
,
self
.
max_cell_num
])
return
sp_tokens
def
__call__
(
self
,
preds
):
structure_probs
=
preds
[
'structure_probs'
]
loc_preds
=
preds
[
'loc_preds'
]
if
isinstance
(
structure_probs
,
paddle
.
Tensor
):
structure_probs
=
structure_probs
.
numpy
()
if
isinstance
(
loc_preds
,
paddle
.
Tensor
):
loc_preds
=
loc_preds
.
numpy
()
structure_idx
=
structure_probs
.
argmax
(
axis
=
2
)
structure_probs
=
structure_probs
.
max
(
axis
=
2
)
structure_str
,
structure_pos
,
result_score_list
,
result_elem_idx_list
=
self
.
decode
(
structure_idx
,
structure_probs
,
'elem'
)
res_html_code_list
=
[]
res_loc_list
=
[]
batch_num
=
len
(
structure_str
)
for
bno
in
range
(
batch_num
):
res_loc
=
[]
for
sno
in
range
(
len
(
structure_str
[
bno
])):
text
=
structure_str
[
bno
][
sno
]
if
text
in
[
'<td>'
,
'<td'
]:
pos
=
structure_pos
[
bno
][
sno
]
res_loc
.
append
(
loc_preds
[
bno
,
pos
])
res_html_code
=
''
.
join
(
structure_str
[
bno
])
res_loc
=
np
.
array
(
res_loc
)
res_html_code_list
.
append
(
res_html_code
)
res_loc_list
.
append
(
res_loc
)
return
{
'res_html_code'
:
res_html_code_list
,
'res_loc'
:
res_loc_list
,
'res_score_list'
:
result_score_list
,
'res_elem_idx_list'
:
result_elem_idx_list
,
'structure_str_list'
:
structure_str
}
def
decode
(
self
,
text_index
,
structure_probs
,
char_or_elem
):
"""convert text-label into text-index.
"""
if
char_or_elem
==
"char"
:
max_len
=
self
.
max_text_length
current_dict
=
self
.
dict_idx_character
else
:
max_len
=
self
.
max_elem_length
current_dict
=
self
.
dict_idx_elem
ignored_tokens
=
self
.
get_ignored_tokens
(
'elem'
)
beg_idx
,
end_idx
=
ignored_tokens
# select_td_tokens = []
# select_span_tokens = []
# for elem in self.dict_elem:
# # if elem == '<td>' or elem == '<td' or elem == '<tr>'\
# # or 'rowspan' in elem or 'colspan' in elem:
# if elem == '<td>' or elem == '<td' or elem == '<tr>':
# select_td_tokens.append(self.dict_elem[elem])
# if 'rowspan' in elem or 'colspan' in elem:
# select_span_tokens.append(self.dict_elem[elem])
result_list
=
[]
result_pos_list
=
[]
result_score_list
=
[]
result_elem_idx_list
=
[]
batch_size
=
len
(
text_index
)
for
batch_idx
in
range
(
batch_size
):
char_list
=
[]
elem_pos_list
=
[]
elem_idx_list
=
[]
score_list
=
[]
for
idx
in
range
(
len
(
text_index
[
batch_idx
])):
tmp_elem_idx
=
int
(
text_index
[
batch_idx
][
idx
])
if
idx
>
0
and
tmp_elem_idx
==
end_idx
:
break
if
tmp_elem_idx
in
ignored_tokens
:
continue
# if tmp_elem_idx in select_td_tokens:
# total_td_score += structure_probs[batch_idx, idx]
# total_td_num += 1
# if tmp_elem_idx in select_span_tokens:
# total_span_score += structure_probs[batch_idx, idx]
# total_span_num += 1
char_list
.
append
(
current_dict
[
tmp_elem_idx
])
elem_pos_list
.
append
(
idx
)
score_list
.
append
(
structure_probs
[
batch_idx
,
idx
])
elem_idx_list
.
append
(
tmp_elem_idx
)
result_list
.
append
(
char_list
)
result_pos_list
.
append
(
elem_pos_list
)
result_score_list
.
append
(
score_list
)
result_elem_idx_list
.
append
(
elem_idx_list
)
return
result_list
,
result_pos_list
,
result_score_list
,
result_elem_idx_list
def
get_ignored_tokens
(
self
,
char_or_elem
):
beg_idx
=
self
.
get_beg_end_flag_idx
(
"beg"
,
char_or_elem
)
end_idx
=
self
.
get_beg_end_flag_idx
(
"end"
,
char_or_elem
)
return
[
beg_idx
,
end_idx
]
def
get_beg_end_flag_idx
(
self
,
beg_or_end
,
char_or_elem
):
if
char_or_elem
==
"char"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_character
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_character
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of char"
\
%
beg_or_end
elif
char_or_elem
==
"elem"
:
if
beg_or_end
==
"beg"
:
idx
=
self
.
dict_elem
[
self
.
beg_str
]
elif
beg_or_end
==
"end"
:
idx
=
self
.
dict_elem
[
self
.
end_str
]
else
:
assert
False
,
"Unsupport type %s in get_beg_end_flag_idx of elem"
\
%
beg_or_end
else
:
assert
False
,
"Unsupport type %s in char_or_elem"
\
%
char_or_elem
return
idx
ppocr/utils/dict/table_dict.txt
0 → 100644
View file @
79436248
←
</overline>
☆
─
α
⋅
$
ω
ψ
χ
(
υ
≥
σ
,
ρ
ε
0
■
4
8
✗
b
<
✓
Ψ
Ω
€
D
3
Π
H
║
</
>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
√
t
</sub>
x
Β
Γ
Δ
|
ǂ
ɛ
j
̧
➢
̌
′
«
△
▲
#
</b>
'
Ι
+
¶
/
▼
⇑
□
·
7
▪
;
?
➔
∩
C
÷
G
⇒
K
<sup>
O
S
С
W
Α
[
○
_
●
‡
c
z
g
<i>
o
<sub>
〈
〉
s
⩽
w
φ
ʹ
{
»
∣
̆
e
ˆ
∈
τ
◆
ι
∅
∆
∙
∘
Ø
ß
✔
∞
∑
−
×
◊
∗
∖
˃
˂
∫
"
i
&
π
↔
*
∥
æ
∧
.
⁄
ø
Q
∼
6
⁎
:
★
>
a
B
≈
F
J
̄
N
♯
R
V
<overline>
―
Z
♣
^
¤
¥
§
<underline>
¢
£
≦
≤
‖
Λ
©
n
↓
→
↑
r
°
±
v
<b>
♂
k
♀
~
ᅟ
̇
@
”
♦
ł
®
⊕
„
!
</sup>
%
⇓
)
-
1
5
9
=
А
A
‰
⋆
Σ
E
◦
I
※
M
m
̨
⩾
†
</i>
•
U
Y
]
̸
2
‐
–
‒
̂
—
̀
́
’
‘
⋮
⋯
̊
“
̈
≧
q
u
ı
y
</underline>
̃
}
ν
ppocr/utils/dict/table_structure_dict.txt
0 → 100644
View file @
79436248
This diff is collapsed.
Click to expand it.
ppocr/utils/table_utils/matcher.py
0 → 100755
View file @
79436248
import
json
def
distance
(
box_1
,
box_2
):
x1
,
y1
,
x2
,
y2
=
box_1
x3
,
y3
,
x4
,
y4
=
box_2
# min_x = (x1 + x2) / 2
# min_y = (y1 + y2) / 2
# max_x = (x3 + x4) / 2
# max_y = (y3 + y4) / 2
dis
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
+
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
dis_2
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
dis_3
=
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
#dis = pow(min_x - max_x, 2) + pow(min_y - max_y, 2) + pow(x3 - x1, 2) + pow(y3 - y1, 2) + pow(x4- x2, 2) + pow(y4 - y2, 2) + abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
return
dis
+
min
(
dis_2
,
dis_3
)
def
compute_iou
(
rec1
,
rec2
):
"""
computing IoU
:param rec1: (y0, x0, y1, x1), which reflects
(top, left, bottom, right)
:param rec2: (y0, x0, y1, x1)
:return: scala value of IoU
"""
# computing area of each rectangles
rec1
,
rec2
=
rec1
*
1000
,
rec2
*
1000
S_rec1
=
(
rec1
[
2
]
-
rec1
[
0
])
*
(
rec1
[
3
]
-
rec1
[
1
])
S_rec2
=
(
rec2
[
2
]
-
rec2
[
0
])
*
(
rec2
[
3
]
-
rec2
[
1
])
# computing the sum_area
sum_area
=
S_rec1
+
S_rec2
# find the each edge of intersect rectangle
left_line
=
max
(
rec1
[
1
],
rec2
[
1
])
right_line
=
min
(
rec1
[
3
],
rec2
[
3
])
top_line
=
max
(
rec1
[
0
],
rec2
[
0
])
bottom_line
=
min
(
rec1
[
2
],
rec2
[
2
])
# judge if there is an intersect
if
left_line
>=
right_line
or
top_line
>=
bottom_line
:
return
0
else
:
intersect
=
(
right_line
-
left_line
)
*
(
bottom_line
-
top_line
)
return
(
intersect
/
(
sum_area
-
intersect
))
*
1.0
def
matcher_merge
(
ocr_bboxes
,
pred_bboxes
):
# ocr_bboxes: OCR pred_bboxes:端到端
all_dis
=
[]
ious
=
[]
matched
=
{}
for
i
,
gt_box
in
enumerate
(
ocr_bboxes
):
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
((
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
#获取两两cell之间的L1距离和 1- IOU
sorted_distances
=
distances
.
copy
()
# 根据距离和IOU挑选最"近"的cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
def
complex_num
(
pred_bboxes
):
complex_nums
=
[]
for
bbox
in
pred_bboxes
:
distances
=
[]
temp_ious
=
[]
for
pred_bbox
in
pred_bboxes
:
if
bbox
!=
pred_bbox
:
distances
.
append
(
distance
(
bbox
,
pred_bbox
))
temp_ious
.
append
(
compute_iou
(
bbox
,
pred_bbox
))
complex_nums
.
append
(
temp_ious
[
distances
.
index
(
min
(
distances
))])
return
sum
(
complex_nums
)
/
len
(
complex_nums
)
def
get_rows
(
pred_bboxes
):
pre_bbox
=
pred_bboxes
[
0
]
res
=
[]
step
=
0
for
i
in
range
(
len
(
pred_bboxes
)):
bbox
=
pred_bboxes
[
i
]
if
bbox
[
1
]
-
pre_bbox
[
1
]
>
2
or
bbox
[
0
]
-
pre_bbox
[
0
]
<
0
:
break
else
:
res
.
append
(
bbox
)
step
+=
1
for
i
in
range
(
step
):
pred_bboxes
.
pop
(
0
)
return
res
,
pred_bboxes
def
refine_rows
(
pred_bboxes
):
# 微调整行的框,使在一条水平线上
ys_1
=
[]
ys_2
=
[]
for
box
in
pred_bboxes
:
ys_1
.
append
(
box
[
1
])
ys_2
.
append
(
box
[
3
])
min_y_1
=
sum
(
ys_1
)
/
len
(
ys_1
)
min_y_2
=
sum
(
ys_2
)
/
len
(
ys_2
)
re_boxes
=
[]
for
box
in
pred_bboxes
:
box
[
1
]
=
min_y_1
box
[
3
]
=
min_y_2
re_boxes
.
append
(
box
)
return
re_boxes
def
matcher_refine_row
(
gt_bboxes
,
pred_bboxes
):
before_refine_pred_bboxes
=
pred_bboxes
.
copy
()
pred_bboxes
=
[]
while
(
len
(
before_refine_pred_bboxes
)
!=
0
):
row_bboxes
,
before_refine_pred_bboxes
=
get_rows
(
before_refine_pred_bboxes
)
print
(
row_bboxes
)
pred_bboxes
.
extend
(
refine_rows
(
row_bboxes
))
all_dis
=
[]
ious
=
[]
matched
=
{}
for
i
,
gt_box
in
enumerate
(
gt_bboxes
):
distances
=
[]
#temp_ious = []
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if
distances
.
index
(
min
(
distances
))
not
in
matched
.
keys
():
matched
[
distances
.
index
(
min
(
distances
))]
=
[
i
]
else
:
matched
[
distances
.
index
(
min
(
distances
))].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
def
matcher_structure_1
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
gt_box_index
=
0
delete_gt_bboxes
=
gt_bboxes
.
copy
()
match_bboxes_ready
=
[]
matched
=
{}
while
(
len
(
delete_gt_bboxes
)
!=
0
):
row_bboxes
,
delete_gt_bboxes
=
get_rows
(
delete_gt_bboxes
)
row_bboxes
=
sorted
(
row_bboxes
,
key
=
lambda
key
:
key
[
0
])
if
len
(
pred_bboxes_rows
)
>
0
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
print
(
row_bboxes
)
for
i
,
gt_box
in
enumerate
(
row_bboxes
):
#print(gt_box)
pred_distances
=
[]
distances
=
[]
for
pred_bbox
in
pred_bboxes
:
pred_distances
.
append
(
distance
(
gt_box
,
pred_bbox
))
for
j
,
pred_box
in
enumerate
(
match_bboxes_ready
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
index
=
pred_distances
.
index
(
min
(
distances
))
#print('index', index)
if
index
not
in
matched
.
keys
():
matched
[
index
]
=
[
gt_box_index
]
else
:
matched
[
index
].
append
(
gt_box_index
)
gt_box_index
+=
1
return
matched
def
matcher_structure
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
'''
gt_bboxes: 排序后
pred_bboxes:
'''
pre_bbox
=
gt_bboxes
[
0
]
matched
=
{}
match_bboxes_ready
=
[]
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
for
i
,
gt_box
in
enumerate
(
gt_bboxes
):
pred_distances
=
[]
for
pred_bbox
in
pred_bboxes
:
pred_distances
.
append
(
distance
(
gt_box
,
pred_bbox
))
distances
=
[]
gap_pre
=
gt_box
[
1
]
-
pre_bbox
[
1
]
gap_pre_1
=
gt_box
[
0
]
-
pre_bbox
[
2
]
#print(gap_pre, len(pred_bboxes_rows))
if
(
gap_pre_1
<
0
and
len
(
pred_bboxes_rows
)
>
0
):
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
pred_bboxes_rows
)
==
1
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
match_bboxes_ready
)
==
0
and
len
(
pred_bboxes_rows
)
>
0
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
match_bboxes_ready
)
==
0
and
len
(
pred_bboxes_rows
)
==
0
:
break
#print(match_bboxes_ready)
for
j
,
pred_box
in
enumerate
(
match_bboxes_ready
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
index
=
pred_distances
.
index
(
min
(
distances
))
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print
(
gt_box
,
match_bboxes_ready
[
distances
.
index
(
min
(
distances
))])
if
index
not
in
matched
.
keys
():
matched
[
index
]
=
[
i
]
else
:
matched
[
index
].
append
(
i
)
pre_bbox
=
gt_box
return
matched
def
main
():
detect_bboxes
=
json
.
load
(
open
(
'./f_detecion_bbox.json'
))
gt_bboxes
=
json
.
load
(
open
(
'./f_gt_bbox.json'
))
all_node
=
0
matched_right
=
0
key
=
'PMC4796501_003_00.png'
print
(
key
)
gt_bbox
=
gt_bboxes
[
key
]
pred_bbox
=
detect_bboxes
[
key
]
matched
=
matcher
(
gt_bbox
,
pred_bbox
)
print
(
matched
)
if
__name__
==
"__main__"
:
main
()
ppstructure/predict_system.py
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
copy
import
numpy
as
np
import
time
import
tools.infer.utility
as
utility
from
tools.infer.predict_system
import
TextSystem
from
ppstructure.table.predict_table
import
TableSystem
,
to_excel
from
ppstructure.layout.predict_layout
import
LayoutDetector
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
logger
=
get_logger
()
def
parse_args
():
parser
=
utility
.
init_args
()
# params for table structure
parser
.
add_argument
(
"--table_max_len"
,
type
=
int
,
default
=
488
)
parser
.
add_argument
(
"--table_max_text_length"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--table_max_elem_length"
,
type
=
int
,
default
=
800
)
parser
.
add_argument
(
"--table_max_cell_num"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--table_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--table_char_type"
,
type
=
str
,
default
=
'en'
)
parser
.
add_argument
(
"--table_char_dict_path"
,
type
=
str
,
default
=
"./ppocr/utils/dict/table_structure_dict.txt"
)
# params for layout detector
parser
.
add_argument
(
"--layout_model_dir"
,
type
=
str
)
return
parser
.
parse_args
()
class
OCRSystem
():
def
__init__
(
self
,
args
):
self
.
text_system
=
TextSystem
(
args
)
self
.
table_system
=
TableSystem
(
args
)
self
.
table_layout
=
LayoutDetector
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
drop_score
=
args
.
drop_score
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
layout_res
=
self
.
table_layout
(
copy
.
deepcopy
(
img
))
for
region
in
layout_res
:
x1
,
y1
,
x2
,
y2
=
region
[
'bbox'
]
roi_img
=
ori_im
[
y1
:
y2
,
x1
:
x2
,:]
if
region
[
'label'
]
==
'table'
:
res
=
self
.
table_system
(
roi_img
)
else
:
res
=
self
.
text_system
(
roi_img
)
region
[
'res'
]
=
res
return
layout_res
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
excel_save_folder
=
'output/table'
os
.
makedirs
(
excel_save_folder
,
exist_ok
=
True
)
text_sys
=
OCRSystem
(
args
)
img_num
=
len
(
image_file_list
)
for
i
,
image_file
in
enumerate
(
image_file_list
):
logger
.
info
(
"[{}/{}] {}"
.
format
(
i
,
img_num
,
image_file
))
img
,
flag
=
check_and_read_gif
(
image_file
)
imgname
=
os
.
path
.
basename
(
image_file
).
split
(
'.'
)[
0
]
# excel_path = os.path.join(excel_save_folder, + '.xlsx')
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
res
=
text_sys
(
img
)
for
region
in
res
:
if
region
[
'label'
]
==
'table'
:
# x1, y1, x2, y2 = region['bbox']
excel_path
=
os
.
path
.
join
(
excel_save_folder
,
'{}_{}.xlsx'
.
format
(
imgname
,
region
[
'bbox'
]))
to_excel
(
region
[
'res'
],
excel_path
)
logger
.
info
(
res
)
elapse
=
time
.
time
()
-
starttime
logger
.
info
(
"Predict time : {:.3f}s"
.
format
(
elapse
))
if
__name__
==
"__main__"
:
args
=
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
ppstructure/table/__init__.py
0 → 100644
View file @
79436248
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ppstructure/table/eval_table.py
0 → 100755
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
import
cv2
import
json
from
tqdm
import
tqdm
from
ppstructure.table.table_metric
import
TEDS
from
ppstructure.table.predict_table
import
TableSystem
,
utility
def
main
(
gt_path
,
img_root
,
args
):
teds
=
TEDS
(
n_jobs
=
16
)
text_sys
=
TableSystem
(
args
)
jsons_gt
=
json
.
load
(
open
(
gt_path
))
# gt
pred_htmls
=
[]
gt_htmls
=
[]
for
img_name
in
tqdm
(
jsons_gt
):
if
img_name
!=
'PMC1064865_002_00.png'
:
continue
# 读取信息
img
=
cv2
.
imread
(
os
.
path
.
join
(
img_root
,
img_name
))
pred_html
=
text_sys
(
img
)
pred_htmls
.
append
(
pred_html
)
gt_structures
,
gt_bboxes
,
gt_contents
,
contents_with_block
=
jsons_gt
[
img_name
]
gt_html
,
gt
=
get_gt_html
(
gt_structures
,
contents_with_block
)
# 获取HTMLgt
gt_htmls
.
append
(
gt_html
)
scores
=
teds
.
batch_evaluate_html
(
gt_htmls
,
pred_htmls
)
# 计算teds
print
(
'teds:'
,
sum
(
scores
)
/
len
(
scores
))
def
get_gt_html
(
gt_structures
,
contents_with_block
):
end_html
=
[]
td_index
=
0
for
tag
in
gt_structures
:
if
'</td>'
in
tag
:
if
contents_with_block
[
td_index
]
!=
[]:
end_html
.
extend
(
contents_with_block
[
td_index
])
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
if
__name__
==
'__main__'
:
args
=
utility
.
parse_args
()
gt_path
=
'table/match_code/f_gt_bbox.json'
img_root
=
'table/imgs'
main
(
gt_path
,
img_root
,
args
)
ppstructure/table/predict_structure.py
0 → 100755
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
numpy
as
np
import
math
import
time
import
traceback
import
paddle
import
tools.infer.utility
as
utility
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
logger
=
get_logger
()
class
TableStructurer
(
object
):
def
__init__
(
self
,
args
):
pre_process_list
=
[{
'ResizeTableImage'
:
{
'max_len'
:
args
.
table_max_len
}
},
{
'NormalizeImage'
:
{
'std'
:
[
0.229
,
0.224
,
0.225
],
'mean'
:
[
0.485
,
0.456
,
0.406
],
'scale'
:
'1./255.'
,
'order'
:
'hwc'
}
},
{
'PaddingTableImage'
:
None
},
{
'ToCHWImage'
:
None
},
{
'KeepKeys'
:
{
'keep_keys'
:
[
'image'
]
}
}]
postprocess_params
=
{
'name'
:
'TableLabelDecode'
,
"character_type"
:
args
.
table_char_type
,
"character_dict_path"
:
args
.
table_char_dict_path
,
"max_text_length"
:
args
.
table_max_text_length
,
"max_elem_length"
:
args
.
table_max_elem_length
,
"max_cell_num"
:
args
.
table_max_cell_num
}
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
'table'
,
logger
)
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
transform
(
data
,
self
.
preprocess_op
)
img
=
data
[
0
]
if
img
is
None
:
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
{}
preds
[
'structure_probs'
]
=
outputs
[
1
]
preds
[
'loc_preds'
]
=
outputs
[
0
]
post_result
=
self
.
postprocess_op
(
preds
)
structure_str_list
=
post_result
[
'structure_str_list'
]
res_loc
=
post_result
[
'res_loc'
]
imgh
,
imgw
=
ori_im
.
shape
[
0
:
2
]
res_loc_final
=
[]
for
rno
in
range
(
len
(
res_loc
[
0
])):
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
structure_str_list
=
structure_str_list
[
0
][:
-
1
]
structure_str_list
=
[
'<html>'
,
'<body>'
,
'<table>'
]
+
structure_str_list
+
[
'</table>'
,
'</body>'
,
'</html>'
]
elapse
=
time
.
time
()
-
starttime
return
(
structure_str_list
,
res_loc_final
),
elapse
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
table_structurer
=
TableStructurer
(
args
)
count
=
0
total_time
=
0
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
structure_res
,
elapse
=
table_structurer
(
img
)
logger
.
info
(
"result: {}"
.
format
(
structure_res
))
if
count
>
0
:
total_time
+=
elapse
count
+=
1
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
ppstructure/table/predict_table.py
0 → 100644
View file @
79436248
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
copy
import
numpy
as
np
import
time
import
tools.infer.utility
as
utility
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_det
as
predict_det
import
ppstructure.table.predict_structure
as
predict_strture
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.table_utils.matcher
import
distance
,
compute_iou
logger
=
get_logger
()
def
expand
(
pix
,
det_box
,
shape
):
x0
,
y0
,
x1
,
y1
=
det_box
# print(shape)
h
,
w
,
c
=
shape
tmp_x0
=
x0
-
pix
tmp_x1
=
x1
+
pix
tmp_y0
=
y0
-
pix
tmp_y1
=
y1
+
pix
x0_
=
tmp_x0
if
tmp_x0
>=
0
else
0
x1_
=
tmp_x1
if
tmp_x1
<=
w
else
w
y0_
=
tmp_y0
if
tmp_y0
>=
0
else
0
y1_
=
tmp_y1
if
tmp_y1
<=
h
else
h
return
x0_
,
y0_
,
x1_
,
y1_
class
TableSystem
(
object
):
def
__init__
(
self
,
args
):
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
self
.
table_structurer
=
predict_strture
.
TableStructurer
(
args
)
self
.
use_angle_cls
=
args
.
use_angle_cls
self
.
drop_score
=
args
.
drop_score
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
structure_res
,
elapse
=
self
.
table_structurer
(
copy
.
deepcopy
(
img
))
dt_boxes
,
elapse
=
self
.
text_detector
(
copy
.
deepcopy
(
img
))
dt_boxes
=
sorted_boxes
(
dt_boxes
)
r_boxes
=
[]
for
box
in
dt_boxes
:
x_min
=
box
[:,
0
].
min
()
-
1
x_max
=
box
[:,
0
].
max
()
+
1
y_min
=
box
[:,
1
].
min
()
-
1
y_max
=
box
[:,
1
].
max
()
+
1
box
=
[
x_min
,
y_min
,
x_max
,
y_max
]
r_boxes
.
append
(
box
)
dt_boxes
=
np
.
array
(
r_boxes
)
# logger.info("dt_boxes num : {}, elapse : {}".format(
# len(dt_boxes), elapse))
if
dt_boxes
is
None
:
return
None
,
None
img_crop_list
=
[]
for
i
in
range
(
len
(
dt_boxes
)):
det_box
=
dt_boxes
[
i
]
x0
,
y0
,
x1
,
y1
=
expand
(
2
,
det_box
,
ori_im
.
shape
)
text_rect
=
ori_im
[
int
(
y0
):
int
(
y1
),
int
(
x0
):
int
(
x1
),
:]
img_crop_list
.
append
(
text_rect
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
# logger.info("rec_res num : {}, elapse : {}".format(
# len(rec_res), elapse))
pred_html
,
pred
=
self
.
rebuild_table
(
structure_res
,
dt_boxes
,
rec_res
)
return
pred_html
def
rebuild_table
(
self
,
structure_res
,
dt_boxes
,
rec_res
):
pred_structures
,
pred_bboxes
=
structure_res
matched_index
=
self
.
match_result
(
dt_boxes
,
pred_bboxes
)
pred_html
,
pred
=
self
.
get_pred_html
(
pred_structures
,
matched_index
,
rec_res
)
return
pred_html
,
pred
def
match_result
(
self
,
dt_boxes
,
pred_bboxes
):
matched
=
{}
for
i
,
gt_box
in
enumerate
(
dt_boxes
):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
(
(
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
# 获取两两cell之间的L1距离和 1- IOU
sorted_distances
=
distances
.
copy
()
# 根据距离和IOU挑选最"近"的cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
def
get_pred_html
(
self
,
pred_structures
,
matched_index
,
ocr_contents
):
end_html
=
[]
td_index
=
0
for
tag
in
pred_structures
:
if
'</td>'
in
tag
:
if
td_index
in
matched_index
.
keys
():
b_with
=
False
if
'<b>'
in
ocr_contents
[
matched_index
[
td_index
][
0
]]
and
len
(
matched_index
[
td_index
])
>
1
:
b_with
=
True
end_html
.
extend
(
'<b>'
)
for
i
,
td_index_index
in
enumerate
(
matched_index
[
td_index
]):
content
=
ocr_contents
[
td_index_index
][
0
]
if
len
(
matched_index
[
td_index
])
>
1
:
if
len
(
content
)
==
0
:
continue
if
content
[
0
]
==
' '
:
content
=
content
[
1
:]
if
'<b>'
in
content
:
content
=
content
[
3
:]
if
'</b>'
in
content
:
content
=
content
[:
-
4
]
if
len
(
content
)
==
0
:
continue
if
i
!=
len
(
matched_index
[
td_index
])
-
1
and
' '
!=
content
[
-
1
]:
content
+=
' '
end_html
.
extend
(
content
)
if
b_with
:
end_html
.
extend
(
'</b>'
)
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
def
sorted_boxes
(
dt_boxes
):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes
=
dt_boxes
.
shape
[
0
]
sorted_boxes
=
sorted
(
dt_boxes
,
key
=
lambda
x
:
(
x
[
0
][
1
],
x
[
0
][
0
]))
_boxes
=
list
(
sorted_boxes
)
for
i
in
range
(
num_boxes
-
1
):
if
abs
(
_boxes
[
i
+
1
][
0
][
1
]
-
_boxes
[
i
][
0
][
1
])
<
10
and
\
(
_boxes
[
i
+
1
][
0
][
0
]
<
_boxes
[
i
][
0
][
0
]):
tmp
=
_boxes
[
i
]
_boxes
[
i
]
=
_boxes
[
i
+
1
]
_boxes
[
i
+
1
]
=
tmp
return
_boxes
def
to_excel
(
html_table
,
excel_path
):
from
tablepyxl
import
tablepyxl
tablepyxl
.
document_to_xl
(
html_table
,
excel_path
)
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
excel_save_folder
=
'output/table'
os
.
makedirs
(
excel_save_folder
,
exist_ok
=
True
)
text_sys
=
TableSystem
(
args
)
img_num
=
len
(
image_file_list
)
for
i
,
image_file
in
enumerate
(
image_file_list
):
logger
.
info
(
"[{}/{}] {}"
.
format
(
i
,
img_num
,
image_file
))
img
,
flag
=
check_and_read_gif
(
image_file
)
excel_path
=
os
.
path
.
join
(
excel_save_folder
,
os
.
path
.
basename
(
image_file
).
split
(
'.'
)[
0
]
+
'.xlsx'
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
pred_html
=
text_sys
(
img
)
to_excel
(
pred_html
,
excel_path
)
logger
.
info
(
'excel saved to {}'
.
format
(
excel_path
))
logger
.
info
(
pred_html
)
elapse
=
time
.
time
()
-
starttime
logger
.
info
(
"Predict time : {:.3f}s"
.
format
(
elapse
))
if
__name__
==
"__main__"
:
args
=
utility
.
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
ppstructure/table/table_metric/__init__.py
0 → 100755
View file @
79436248
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__
=
[
'TEDS'
]
from
.table_metric
import
TEDS
\ No newline at end of file
ppstructure/table/table_metric/parallel.py
0 → 100755
View file @
79436248
from
tqdm
import
tqdm
from
concurrent.futures
import
ProcessPoolExecutor
,
as_completed
def
parallel_process
(
array
,
function
,
n_jobs
=
16
,
use_kwargs
=
False
,
front_num
=
0
):
"""
A parallel version of the map function with a progress bar.
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
n_jobs (int, default=16): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
# We run the first few iterations serially to catch bugs
if
front_num
>
0
:
front
=
[
function
(
**
a
)
if
use_kwargs
else
function
(
a
)
for
a
in
array
[:
front_num
]]
else
:
front
=
[]
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if
n_jobs
==
1
:
return
front
+
[
function
(
**
a
)
if
use_kwargs
else
function
(
a
)
for
a
in
tqdm
(
array
[
front_num
:])]
# Assemble the workers
with
ProcessPoolExecutor
(
max_workers
=
n_jobs
)
as
pool
:
# Pass the elements of array into function
if
use_kwargs
:
futures
=
[
pool
.
submit
(
function
,
**
a
)
for
a
in
array
[
front_num
:]]
else
:
futures
=
[
pool
.
submit
(
function
,
a
)
for
a
in
array
[
front_num
:]]
kwargs
=
{
'total'
:
len
(
futures
),
'unit'
:
'it'
,
'unit_scale'
:
True
,
'leave'
:
True
}
# Print out the progress as tasks complete
for
f
in
tqdm
(
as_completed
(
futures
),
**
kwargs
):
pass
out
=
[]
# Get the results from the futures.
for
i
,
future
in
tqdm
(
enumerate
(
futures
)):
try
:
out
.
append
(
future
.
result
())
except
Exception
as
e
:
out
.
append
(
e
)
return
front
+
out
tools/infer/utility.py
View file @
79436248
...
...
@@ -125,6 +125,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
elif
mode
==
'table'
:
model_dir
=
args
.
table_model_dir
else
:
model_dir
=
args
.
e2e_model_dir
...
...
@@ -244,7 +246,8 @@ def create_predictor(args, mode, logger):
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
switch_use_feed_fetch_ops
(
False
)
if
mode
==
'table'
:
config
.
switch_ir_optim
(
False
)
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment