Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
037e17fc
Commit
037e17fc
authored
Jun 10, 2021
by
WenmuZhou
Browse files
merge dygraph
parent
6127aad9
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1421 additions
and
3 deletions
+1421
-3
ppstructure/table/eval_table.py
ppstructure/table/eval_table.py
+72
-0
ppstructure/table/matcher.py
ppstructure/table/matcher.py
+192
-0
ppstructure/table/predict_structure.py
ppstructure/table/predict_structure.py
+141
-0
ppstructure/table/predict_table.py
ppstructure/table/predict_table.py
+221
-0
ppstructure/table/table_metric/__init__.py
ppstructure/table/table_metric/__init__.py
+16
-0
ppstructure/table/table_metric/parallel.py
ppstructure/table/table_metric/parallel.py
+51
-0
ppstructure/table/table_metric/table_metric.py
ppstructure/table/table_metric/table_metric.py
+247
-0
ppstructure/table/tablepyxl/__init__.py
ppstructure/table/tablepyxl/__init__.py
+13
-0
ppstructure/table/tablepyxl/style.py
ppstructure/table/tablepyxl/style.py
+283
-0
ppstructure/table/tablepyxl/tablepyxl.py
ppstructure/table/tablepyxl/tablepyxl.py
+118
-0
ppstructure/utility.py
ppstructure/utility.py
+59
-0
tools/infer/predict_det.py
tools/infer/predict_det.py
+1
-1
tools/infer/utility.py
tools/infer/utility.py
+7
-2
No files found.
ppstructure/table/eval_table.py
0 → 100755
View file @
037e17fc
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
import
cv2
import
json
from
tqdm
import
tqdm
from
ppstructure.table.table_metric
import
TEDS
from
ppstructure.table.predict_table
import
TableSystem
from
ppstructure.utility
import
init_args
from
ppocr.utils.logging
import
get_logger
logger
=
get_logger
()
def
parse_args
():
parser
=
init_args
()
parser
.
add_argument
(
"--gt_path"
,
type
=
str
)
return
parser
.
parse_args
()
def
main
(
gt_path
,
img_root
,
args
):
teds
=
TEDS
(
n_jobs
=
16
)
text_sys
=
TableSystem
(
args
)
jsons_gt
=
json
.
load
(
open
(
gt_path
))
# gt
pred_htmls
=
[]
gt_htmls
=
[]
for
img_name
in
tqdm
(
jsons_gt
):
# read image
img
=
cv2
.
imread
(
os
.
path
.
join
(
img_root
,
img_name
))
pred_html
=
text_sys
(
img
)
pred_htmls
.
append
(
pred_html
)
gt_structures
,
gt_bboxes
,
gt_contents
,
contents_with_block
=
jsons_gt
[
img_name
]
gt_html
,
gt
=
get_gt_html
(
gt_structures
,
contents_with_block
)
gt_htmls
.
append
(
gt_html
)
scores
=
teds
.
batch_evaluate_html
(
gt_htmls
,
pred_htmls
)
logger
.
info
(
'teds:'
,
sum
(
scores
)
/
len
(
scores
))
def
get_gt_html
(
gt_structures
,
contents_with_block
):
end_html
=
[]
td_index
=
0
for
tag
in
gt_structures
:
if
'</td>'
in
tag
:
if
contents_with_block
[
td_index
]
!=
[]:
end_html
.
extend
(
contents_with_block
[
td_index
])
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
if
__name__
==
'__main__'
:
args
=
parse_args
()
main
(
args
.
gt_path
,
args
.
image_dir
,
args
)
ppstructure/table/matcher.py
0 → 100755
View file @
037e17fc
import
json
def
distance
(
box_1
,
box_2
):
x1
,
y1
,
x2
,
y2
=
box_1
x3
,
y3
,
x4
,
y4
=
box_2
dis
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
+
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
dis_2
=
abs
(
x3
-
x1
)
+
abs
(
y3
-
y1
)
dis_3
=
abs
(
x4
-
x2
)
+
abs
(
y4
-
y2
)
return
dis
+
min
(
dis_2
,
dis_3
)
def
compute_iou
(
rec1
,
rec2
):
"""
computing IoU
:param rec1: (y0, x0, y1, x1), which reflects
(top, left, bottom, right)
:param rec2: (y0, x0, y1, x1)
:return: scala value of IoU
"""
# computing area of each rectangles
S_rec1
=
(
rec1
[
2
]
-
rec1
[
0
])
*
(
rec1
[
3
]
-
rec1
[
1
])
S_rec2
=
(
rec2
[
2
]
-
rec2
[
0
])
*
(
rec2
[
3
]
-
rec2
[
1
])
# computing the sum_area
sum_area
=
S_rec1
+
S_rec2
# find the each edge of intersect rectangle
left_line
=
max
(
rec1
[
1
],
rec2
[
1
])
right_line
=
min
(
rec1
[
3
],
rec2
[
3
])
top_line
=
max
(
rec1
[
0
],
rec2
[
0
])
bottom_line
=
min
(
rec1
[
2
],
rec2
[
2
])
# judge if there is an intersect
if
left_line
>=
right_line
or
top_line
>=
bottom_line
:
return
0.0
else
:
intersect
=
(
right_line
-
left_line
)
*
(
bottom_line
-
top_line
)
return
(
intersect
/
(
sum_area
-
intersect
))
*
1.0
def
matcher_merge
(
ocr_bboxes
,
pred_bboxes
):
all_dis
=
[]
ious
=
[]
matched
=
{}
for
i
,
gt_box
in
enumerate
(
ocr_bboxes
):
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
# compute l1 distence and IOU between two boxes
distances
.
append
((
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
sorted_distances
=
distances
.
copy
()
# select nearest cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
def
complex_num
(
pred_bboxes
):
complex_nums
=
[]
for
bbox
in
pred_bboxes
:
distances
=
[]
temp_ious
=
[]
for
pred_bbox
in
pred_bboxes
:
if
bbox
!=
pred_bbox
:
distances
.
append
(
distance
(
bbox
,
pred_bbox
))
temp_ious
.
append
(
compute_iou
(
bbox
,
pred_bbox
))
complex_nums
.
append
(
temp_ious
[
distances
.
index
(
min
(
distances
))])
return
sum
(
complex_nums
)
/
len
(
complex_nums
)
def
get_rows
(
pred_bboxes
):
pre_bbox
=
pred_bboxes
[
0
]
res
=
[]
step
=
0
for
i
in
range
(
len
(
pred_bboxes
)):
bbox
=
pred_bboxes
[
i
]
if
bbox
[
1
]
-
pre_bbox
[
1
]
>
2
or
bbox
[
0
]
-
pre_bbox
[
0
]
<
0
:
break
else
:
res
.
append
(
bbox
)
step
+=
1
for
i
in
range
(
step
):
pred_bboxes
.
pop
(
0
)
return
res
,
pred_bboxes
def
refine_rows
(
pred_bboxes
):
# 微调整行的框,使在一条水平线上
ys_1
=
[]
ys_2
=
[]
for
box
in
pred_bboxes
:
ys_1
.
append
(
box
[
1
])
ys_2
.
append
(
box
[
3
])
min_y_1
=
sum
(
ys_1
)
/
len
(
ys_1
)
min_y_2
=
sum
(
ys_2
)
/
len
(
ys_2
)
re_boxes
=
[]
for
box
in
pred_bboxes
:
box
[
1
]
=
min_y_1
box
[
3
]
=
min_y_2
re_boxes
.
append
(
box
)
return
re_boxes
def
matcher_refine_row
(
gt_bboxes
,
pred_bboxes
):
before_refine_pred_bboxes
=
pred_bboxes
.
copy
()
pred_bboxes
=
[]
while
(
len
(
before_refine_pred_bboxes
)
!=
0
):
row_bboxes
,
before_refine_pred_bboxes
=
get_rows
(
before_refine_pred_bboxes
)
print
(
row_bboxes
)
pred_bboxes
.
extend
(
refine_rows
(
row_bboxes
))
all_dis
=
[]
ious
=
[]
matched
=
{}
for
i
,
gt_box
in
enumerate
(
gt_bboxes
):
distances
=
[]
#temp_ious = []
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
#temp_ious.append(compute_iou(gt_box, pred_box))
#all_dis.append(min(distances))
#ious.append(temp_ious[distances.index(min(distances))])
if
distances
.
index
(
min
(
distances
))
not
in
matched
.
keys
():
matched
[
distances
.
index
(
min
(
distances
))]
=
[
i
]
else
:
matched
[
distances
.
index
(
min
(
distances
))].
append
(
i
)
return
matched
#, sum(ious) / len(ious)
#先挑选出一行,再进行匹配
def
matcher_structure_1
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
gt_box_index
=
0
delete_gt_bboxes
=
gt_bboxes
.
copy
()
match_bboxes_ready
=
[]
matched
=
{}
while
(
len
(
delete_gt_bboxes
)
!=
0
):
row_bboxes
,
delete_gt_bboxes
=
get_rows
(
delete_gt_bboxes
)
row_bboxes
=
sorted
(
row_bboxes
,
key
=
lambda
key
:
key
[
0
])
if
len
(
pred_bboxes_rows
)
>
0
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
print
(
row_bboxes
)
for
i
,
gt_box
in
enumerate
(
row_bboxes
):
#print(gt_box)
pred_distances
=
[]
distances
=
[]
for
pred_bbox
in
pred_bboxes
:
pred_distances
.
append
(
distance
(
gt_box
,
pred_bbox
))
for
j
,
pred_box
in
enumerate
(
match_bboxes_ready
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
index
=
pred_distances
.
index
(
min
(
distances
))
#print('index', index)
if
index
not
in
matched
.
keys
():
matched
[
index
]
=
[
gt_box_index
]
else
:
matched
[
index
].
append
(
gt_box_index
)
gt_box_index
+=
1
return
matched
def
matcher_structure
(
gt_bboxes
,
pred_bboxes_rows
,
pred_bboxes
):
'''
gt_bboxes: 排序后
pred_bboxes:
'''
pre_bbox
=
gt_bboxes
[
0
]
matched
=
{}
match_bboxes_ready
=
[]
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
for
i
,
gt_box
in
enumerate
(
gt_bboxes
):
pred_distances
=
[]
for
pred_bbox
in
pred_bboxes
:
pred_distances
.
append
(
distance
(
gt_box
,
pred_bbox
))
distances
=
[]
gap_pre
=
gt_box
[
1
]
-
pre_bbox
[
1
]
gap_pre_1
=
gt_box
[
0
]
-
pre_bbox
[
2
]
#print(gap_pre, len(pred_bboxes_rows))
if
(
gap_pre_1
<
0
and
len
(
pred_bboxes_rows
)
>
0
):
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
pred_bboxes_rows
)
==
1
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
match_bboxes_ready
)
==
0
and
len
(
pred_bboxes_rows
)
>
0
:
match_bboxes_ready
.
extend
(
pred_bboxes_rows
.
pop
(
0
))
if
len
(
match_bboxes_ready
)
==
0
and
len
(
pred_bboxes_rows
)
==
0
:
break
#print(match_bboxes_ready)
for
j
,
pred_box
in
enumerate
(
match_bboxes_ready
):
distances
.
append
(
distance
(
gt_box
,
pred_box
))
index
=
pred_distances
.
index
(
min
(
distances
))
#print(gt_box, index)
#match_bboxes_ready.pop(distances.index(min(distances)))
print
(
gt_box
,
match_bboxes_ready
[
distances
.
index
(
min
(
distances
))])
if
index
not
in
matched
.
keys
():
matched
[
index
]
=
[
i
]
else
:
matched
[
index
].
append
(
i
)
pre_bbox
=
gt_box
return
matched
ppstructure/table/predict_structure.py
0 → 100755
View file @
037e17fc
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
numpy
as
np
import
math
import
time
import
traceback
import
paddle
import
tools.infer.utility
as
utility
from
ppocr.data
import
create_operators
,
transform
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
logger
=
get_logger
()
class
TableStructurer
(
object
):
def
__init__
(
self
,
args
):
pre_process_list
=
[{
'ResizeTableImage'
:
{
'max_len'
:
args
.
structure_max_len
}
},
{
'NormalizeImage'
:
{
'std'
:
[
0.229
,
0.224
,
0.225
],
'mean'
:
[
0.485
,
0.456
,
0.406
],
'scale'
:
'1./255.'
,
'order'
:
'hwc'
}
},
{
'PaddingTableImage'
:
None
},
{
'ToCHWImage'
:
None
},
{
'KeepKeys'
:
{
'keep_keys'
:
[
'image'
]
}
}]
postprocess_params
=
{
'name'
:
'TableLabelDecode'
,
"character_type"
:
args
.
structure_char_type
,
"character_dict_path"
:
args
.
structure_char_dict_path
,
"max_text_length"
:
args
.
structure_max_text_length
,
"max_elem_length"
:
args
.
structure_max_elem_length
,
"max_cell_num"
:
args
.
structure_max_cell_num
}
self
.
preprocess_op
=
create_operators
(
pre_process_list
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
=
\
utility
.
create_predictor
(
args
,
'structure'
,
logger
)
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
data
=
{
'image'
:
img
}
data
=
transform
(
data
,
self
.
preprocess_op
)
img
=
data
[
0
]
if
img
is
None
:
return
None
,
0
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
img
.
copy
()
starttime
=
time
.
time
()
self
.
input_tensor
.
copy_from_cpu
(
img
)
self
.
predictor
.
run
()
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
preds
=
{}
preds
[
'structure_probs'
]
=
outputs
[
1
]
preds
[
'loc_preds'
]
=
outputs
[
0
]
post_result
=
self
.
postprocess_op
(
preds
)
structure_str_list
=
post_result
[
'structure_str_list'
]
res_loc
=
post_result
[
'res_loc'
]
imgh
,
imgw
=
ori_im
.
shape
[
0
:
2
]
res_loc_final
=
[]
for
rno
in
range
(
len
(
res_loc
[
0
])):
x0
,
y0
,
x1
,
y1
=
res_loc
[
0
][
rno
]
left
=
max
(
int
(
imgw
*
x0
),
0
)
top
=
max
(
int
(
imgh
*
y0
),
0
)
right
=
min
(
int
(
imgw
*
x1
),
imgw
-
1
)
bottom
=
min
(
int
(
imgh
*
y1
),
imgh
-
1
)
res_loc_final
.
append
([
left
,
top
,
right
,
bottom
])
structure_str_list
=
structure_str_list
[
0
][:
-
1
]
structure_str_list
=
[
'<html>'
,
'<body>'
,
'<table>'
]
+
structure_str_list
+
[
'</table>'
,
'</body>'
,
'</html>'
]
elapse
=
time
.
time
()
-
starttime
return
(
structure_str_list
,
res_loc_final
),
elapse
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
table_structurer
=
TableStructurer
(
args
)
count
=
0
total_time
=
0
for
image_file
in
image_file_list
:
img
,
flag
=
check_and_read_gif
(
image_file
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
info
(
"error in loading image:{}"
.
format
(
image_file
))
continue
structure_res
,
elapse
=
table_structurer
(
img
)
logger
.
info
(
"result: {}"
.
format
(
structure_res
))
if
count
>
0
:
total_time
+=
elapse
count
+=
1
logger
.
info
(
"Predict time of {}: {}"
.
format
(
image_file
,
elapse
))
if
__name__
==
"__main__"
:
main
(
utility
.
parse_args
())
ppstructure/table/predict_table.py
0 → 100644
View file @
037e17fc
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
sys
import
subprocess
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'..'
)))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
os
.
environ
[
"FLAGS_allocator_strategy"
]
=
'auto_growth'
import
cv2
import
copy
import
numpy
as
np
import
time
import
tools.infer.predict_rec
as
predict_rec
import
tools.infer.predict_det
as
predict_det
from
ppocr.utils.utility
import
get_image_file_list
,
check_and_read_gif
from
ppocr.utils.logging
import
get_logger
from
ppstructure.table.matcher
import
distance
,
compute_iou
from
ppstructure.utility
import
parse_args
import
ppstructure.table.predict_structure
as
predict_strture
logger
=
get_logger
()
def
expand
(
pix
,
det_box
,
shape
):
x0
,
y0
,
x1
,
y1
=
det_box
# print(shape)
h
,
w
,
c
=
shape
tmp_x0
=
x0
-
pix
tmp_x1
=
x1
+
pix
tmp_y0
=
y0
-
pix
tmp_y1
=
y1
+
pix
x0_
=
tmp_x0
if
tmp_x0
>=
0
else
0
x1_
=
tmp_x1
if
tmp_x1
<=
w
else
w
y0_
=
tmp_y0
if
tmp_y0
>=
0
else
0
y1_
=
tmp_y1
if
tmp_y1
<=
h
else
h
return
x0_
,
y0_
,
x1_
,
y1_
class
TableSystem
(
object
):
def
__init__
(
self
,
args
,
text_detector
=
None
,
text_recognizer
=
None
):
self
.
text_detector
=
predict_det
.
TextDetector
(
args
)
if
text_detector
is
None
else
text_detector
self
.
text_recognizer
=
predict_rec
.
TextRecognizer
(
args
)
if
text_recognizer
is
None
else
text_recognizer
self
.
table_structurer
=
predict_strture
.
TableStructurer
(
args
)
def
__call__
(
self
,
img
):
ori_im
=
img
.
copy
()
structure_res
,
elapse
=
self
.
table_structurer
(
copy
.
deepcopy
(
img
))
dt_boxes
,
elapse
=
self
.
text_detector
(
copy
.
deepcopy
(
img
))
dt_boxes
=
sorted_boxes
(
dt_boxes
)
r_boxes
=
[]
for
box
in
dt_boxes
:
x_min
=
box
[:,
0
].
min
()
-
1
x_max
=
box
[:,
0
].
max
()
+
1
y_min
=
box
[:,
1
].
min
()
-
1
y_max
=
box
[:,
1
].
max
()
+
1
box
=
[
x_min
,
y_min
,
x_max
,
y_max
]
r_boxes
.
append
(
box
)
dt_boxes
=
np
.
array
(
r_boxes
)
logger
.
debug
(
"dt_boxes num : {}, elapse : {}"
.
format
(
len
(
dt_boxes
),
elapse
))
if
dt_boxes
is
None
:
return
None
,
None
img_crop_list
=
[]
for
i
in
range
(
len
(
dt_boxes
)):
det_box
=
dt_boxes
[
i
]
x0
,
y0
,
x1
,
y1
=
expand
(
2
,
det_box
,
ori_im
.
shape
)
text_rect
=
ori_im
[
int
(
y0
):
int
(
y1
),
int
(
x0
):
int
(
x1
),
:]
img_crop_list
.
append
(
text_rect
)
rec_res
,
elapse
=
self
.
text_recognizer
(
img_crop_list
)
logger
.
debug
(
"rec_res num : {}, elapse : {}"
.
format
(
len
(
rec_res
),
elapse
))
pred_html
,
pred
=
self
.
rebuild_table
(
structure_res
,
dt_boxes
,
rec_res
)
return
pred_html
def
rebuild_table
(
self
,
structure_res
,
dt_boxes
,
rec_res
):
pred_structures
,
pred_bboxes
=
structure_res
matched_index
=
self
.
match_result
(
dt_boxes
,
pred_bboxes
)
pred_html
,
pred
=
self
.
get_pred_html
(
pred_structures
,
matched_index
,
rec_res
)
return
pred_html
,
pred
def
match_result
(
self
,
dt_boxes
,
pred_bboxes
):
matched
=
{}
for
i
,
gt_box
in
enumerate
(
dt_boxes
):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances
=
[]
for
j
,
pred_box
in
enumerate
(
pred_bboxes
):
distances
.
append
(
(
distance
(
gt_box
,
pred_box
),
1.
-
compute_iou
(
gt_box
,
pred_box
)))
# 获取两两cell之间的L1距离和 1- IOU
sorted_distances
=
distances
.
copy
()
# 根据距离和IOU挑选最"近"的cell
sorted_distances
=
sorted
(
sorted_distances
,
key
=
lambda
item
:
(
item
[
1
],
item
[
0
]))
if
distances
.
index
(
sorted_distances
[
0
])
not
in
matched
.
keys
():
matched
[
distances
.
index
(
sorted_distances
[
0
])]
=
[
i
]
else
:
matched
[
distances
.
index
(
sorted_distances
[
0
])].
append
(
i
)
return
matched
def
get_pred_html
(
self
,
pred_structures
,
matched_index
,
ocr_contents
):
end_html
=
[]
td_index
=
0
for
tag
in
pred_structures
:
if
'</td>'
in
tag
:
if
td_index
in
matched_index
.
keys
():
b_with
=
False
if
'<b>'
in
ocr_contents
[
matched_index
[
td_index
][
0
]]
and
len
(
matched_index
[
td_index
])
>
1
:
b_with
=
True
end_html
.
extend
(
'<b>'
)
for
i
,
td_index_index
in
enumerate
(
matched_index
[
td_index
]):
content
=
ocr_contents
[
td_index_index
][
0
]
if
len
(
matched_index
[
td_index
])
>
1
:
if
len
(
content
)
==
0
:
continue
if
content
[
0
]
==
' '
:
content
=
content
[
1
:]
if
'<b>'
in
content
:
content
=
content
[
3
:]
if
'</b>'
in
content
:
content
=
content
[:
-
4
]
if
len
(
content
)
==
0
:
continue
if
i
!=
len
(
matched_index
[
td_index
])
-
1
and
' '
!=
content
[
-
1
]:
content
+=
' '
end_html
.
extend
(
content
)
if
b_with
:
end_html
.
extend
(
'</b>'
)
end_html
.
append
(
tag
)
td_index
+=
1
else
:
end_html
.
append
(
tag
)
return
''
.
join
(
end_html
),
end_html
def
sorted_boxes
(
dt_boxes
):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes
=
dt_boxes
.
shape
[
0
]
sorted_boxes
=
sorted
(
dt_boxes
,
key
=
lambda
x
:
(
x
[
0
][
1
],
x
[
0
][
0
]))
_boxes
=
list
(
sorted_boxes
)
for
i
in
range
(
num_boxes
-
1
):
if
abs
(
_boxes
[
i
+
1
][
0
][
1
]
-
_boxes
[
i
][
0
][
1
])
<
10
and
\
(
_boxes
[
i
+
1
][
0
][
0
]
<
_boxes
[
i
][
0
][
0
]):
tmp
=
_boxes
[
i
]
_boxes
[
i
]
=
_boxes
[
i
+
1
]
_boxes
[
i
+
1
]
=
tmp
return
_boxes
def
to_excel
(
html_table
,
excel_path
):
from
tablepyxl
import
tablepyxl
tablepyxl
.
document_to_xl
(
html_table
,
excel_path
)
def
main
(
args
):
image_file_list
=
get_image_file_list
(
args
.
image_dir
)
image_file_list
=
image_file_list
[
args
.
process_id
::
args
.
total_process_num
]
os
.
makedirs
(
args
.
output
,
exist_ok
=
True
)
text_sys
=
TableSystem
(
args
)
img_num
=
len
(
image_file_list
)
for
i
,
image_file
in
enumerate
(
image_file_list
):
logger
.
info
(
"[{}/{}] {}"
.
format
(
i
,
img_num
,
image_file
))
img
,
flag
=
check_and_read_gif
(
image_file
)
excel_path
=
os
.
path
.
join
(
args
.
table_output
,
os
.
path
.
basename
(
image_file
).
split
(
'.'
)[
0
]
+
'.xlsx'
)
if
not
flag
:
img
=
cv2
.
imread
(
image_file
)
if
img
is
None
:
logger
.
error
(
"error in loading image:{}"
.
format
(
image_file
))
continue
starttime
=
time
.
time
()
pred_html
=
text_sys
(
img
)
to_excel
(
pred_html
,
excel_path
)
logger
.
info
(
'excel saved to {}'
.
format
(
excel_path
))
logger
.
info
(
pred_html
)
elapse
=
time
.
time
()
-
starttime
logger
.
info
(
"Predict time : {:.3f}s"
.
format
(
elapse
))
if
__name__
==
"__main__"
:
args
=
parse_args
()
if
args
.
use_mp
:
p_list
=
[]
total_process_num
=
args
.
total_process_num
for
process_id
in
range
(
total_process_num
):
cmd
=
[
sys
.
executable
,
"-u"
]
+
sys
.
argv
+
[
"--process_id={}"
.
format
(
process_id
),
"--use_mp={}"
.
format
(
False
)
]
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
sys
.
stdout
,
stderr
=
sys
.
stdout
)
p_list
.
append
(
p
)
for
p
in
p_list
:
p
.
wait
()
else
:
main
(
args
)
ppstructure/table/table_metric/__init__.py
0 → 100755
View file @
037e17fc
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__
=
[
'TEDS'
]
from
.table_metric
import
TEDS
\ No newline at end of file
ppstructure/table/table_metric/parallel.py
0 → 100755
View file @
037e17fc
from
tqdm
import
tqdm
from
concurrent.futures
import
ProcessPoolExecutor
,
as_completed
def
parallel_process
(
array
,
function
,
n_jobs
=
16
,
use_kwargs
=
False
,
front_num
=
0
):
"""
A parallel version of the map function with a progress bar.
Args:
array (array-like): An array to iterate over.
function (function): A python function to apply to the elements of array
n_jobs (int, default=16): The number of cores to use
use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
keyword arguments to function
front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job.
Useful for catching bugs
Returns:
[function(array[0]), function(array[1]), ...]
"""
# We run the first few iterations serially to catch bugs
if
front_num
>
0
:
front
=
[
function
(
**
a
)
if
use_kwargs
else
function
(
a
)
for
a
in
array
[:
front_num
]]
else
:
front
=
[]
# If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
if
n_jobs
==
1
:
return
front
+
[
function
(
**
a
)
if
use_kwargs
else
function
(
a
)
for
a
in
tqdm
(
array
[
front_num
:])]
# Assemble the workers
with
ProcessPoolExecutor
(
max_workers
=
n_jobs
)
as
pool
:
# Pass the elements of array into function
if
use_kwargs
:
futures
=
[
pool
.
submit
(
function
,
**
a
)
for
a
in
array
[
front_num
:]]
else
:
futures
=
[
pool
.
submit
(
function
,
a
)
for
a
in
array
[
front_num
:]]
kwargs
=
{
'total'
:
len
(
futures
),
'unit'
:
'it'
,
'unit_scale'
:
True
,
'leave'
:
True
}
# Print out the progress as tasks complete
for
f
in
tqdm
(
as_completed
(
futures
),
**
kwargs
):
pass
out
=
[]
# Get the results from the futures.
for
i
,
future
in
tqdm
(
enumerate
(
futures
)):
try
:
out
.
append
(
future
.
result
())
except
Exception
as
e
:
out
.
append
(
e
)
return
front
+
out
ppstructure/table/table_metric/table_metric.py
0 → 100755
View file @
037e17fc
# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.
import
distance
from
apted
import
APTED
,
Config
from
apted.helpers
import
Tree
from
lxml
import
etree
,
html
from
collections
import
deque
from
.parallel
import
parallel_process
from
tqdm
import
tqdm
class
TableTree
(
Tree
):
def
__init__
(
self
,
tag
,
colspan
=
None
,
rowspan
=
None
,
content
=
None
,
*
children
):
self
.
tag
=
tag
self
.
colspan
=
colspan
self
.
rowspan
=
rowspan
self
.
content
=
content
self
.
children
=
list
(
children
)
def
bracket
(
self
):
"""Show tree using brackets notation"""
if
self
.
tag
==
'td'
:
result
=
'"tag": %s, "colspan": %d, "rowspan": %d, "text": %s'
%
\
(
self
.
tag
,
self
.
colspan
,
self
.
rowspan
,
self
.
content
)
else
:
result
=
'"tag": %s'
%
self
.
tag
for
child
in
self
.
children
:
result
+=
child
.
bracket
()
return
"{{{}}}"
.
format
(
result
)
class
CustomConfig
(
Config
):
@
staticmethod
def
maximum
(
*
sequences
):
"""Get maximum possible value
"""
return
max
(
map
(
len
,
sequences
))
def
normalized_distance
(
self
,
*
sequences
):
"""Get distance from 0 to 1
"""
return
float
(
distance
.
levenshtein
(
*
sequences
))
/
self
.
maximum
(
*
sequences
)
def
rename
(
self
,
node1
,
node2
):
"""Compares attributes of trees"""
#print(node1.tag)
if
(
node1
.
tag
!=
node2
.
tag
)
or
(
node1
.
colspan
!=
node2
.
colspan
)
or
(
node1
.
rowspan
!=
node2
.
rowspan
):
return
1.
if
node1
.
tag
==
'td'
:
if
node1
.
content
or
node2
.
content
:
#print(node1.content, )
return
self
.
normalized_distance
(
node1
.
content
,
node2
.
content
)
return
0.
class
CustomConfig_del_short
(
Config
):
@
staticmethod
def
maximum
(
*
sequences
):
"""Get maximum possible value
"""
return
max
(
map
(
len
,
sequences
))
def
normalized_distance
(
self
,
*
sequences
):
"""Get distance from 0 to 1
"""
return
float
(
distance
.
levenshtein
(
*
sequences
))
/
self
.
maximum
(
*
sequences
)
def
rename
(
self
,
node1
,
node2
):
"""Compares attributes of trees"""
if
(
node1
.
tag
!=
node2
.
tag
)
or
(
node1
.
colspan
!=
node2
.
colspan
)
or
(
node1
.
rowspan
!=
node2
.
rowspan
):
return
1.
if
node1
.
tag
==
'td'
:
if
node1
.
content
or
node2
.
content
:
#print('before')
#print(node1.content, node2.content)
#print('after')
node1_content
=
node1
.
content
node2_content
=
node2
.
content
if
len
(
node1_content
)
<
3
:
node1_content
=
[
'####'
]
if
len
(
node2_content
)
<
3
:
node2_content
=
[
'####'
]
return
self
.
normalized_distance
(
node1_content
,
node2_content
)
return
0.
class
CustomConfig_del_block
(
Config
):
@
staticmethod
def
maximum
(
*
sequences
):
"""Get maximum possible value
"""
return
max
(
map
(
len
,
sequences
))
def
normalized_distance
(
self
,
*
sequences
):
"""Get distance from 0 to 1
"""
return
float
(
distance
.
levenshtein
(
*
sequences
))
/
self
.
maximum
(
*
sequences
)
def
rename
(
self
,
node1
,
node2
):
"""Compares attributes of trees"""
if
(
node1
.
tag
!=
node2
.
tag
)
or
(
node1
.
colspan
!=
node2
.
colspan
)
or
(
node1
.
rowspan
!=
node2
.
rowspan
):
return
1.
if
node1
.
tag
==
'td'
:
if
node1
.
content
or
node2
.
content
:
node1_content
=
node1
.
content
node2_content
=
node2
.
content
while
' '
in
node1_content
:
print
(
node1_content
.
index
(
' '
))
node1_content
.
pop
(
node1_content
.
index
(
' '
))
while
' '
in
node2_content
:
print
(
node2_content
.
index
(
' '
))
node2_content
.
pop
(
node2_content
.
index
(
' '
))
return
self
.
normalized_distance
(
node1_content
,
node2_content
)
return
0.
class
TEDS
(
object
):
''' Tree Edit Distance basead Similarity
'''
def
__init__
(
self
,
structure_only
=
False
,
n_jobs
=
1
,
ignore_nodes
=
None
):
assert
isinstance
(
n_jobs
,
int
)
and
(
n_jobs
>=
1
),
'n_jobs must be an integer greather than 1'
self
.
structure_only
=
structure_only
self
.
n_jobs
=
n_jobs
self
.
ignore_nodes
=
ignore_nodes
self
.
__tokens__
=
[]
def
tokenize
(
self
,
node
):
''' Tokenizes table cells
'''
self
.
__tokens__
.
append
(
'<%s>'
%
node
.
tag
)
if
node
.
text
is
not
None
:
self
.
__tokens__
+=
list
(
node
.
text
)
for
n
in
node
.
getchildren
():
self
.
tokenize
(
n
)
if
node
.
tag
!=
'unk'
:
self
.
__tokens__
.
append
(
'</%s>'
%
node
.
tag
)
if
node
.
tag
!=
'td'
and
node
.
tail
is
not
None
:
self
.
__tokens__
+=
list
(
node
.
tail
)
def
load_html_tree
(
self
,
node
,
parent
=
None
):
''' Converts HTML tree to the format required by apted
'''
global
__tokens__
if
node
.
tag
==
'td'
:
if
self
.
structure_only
:
cell
=
[]
else
:
self
.
__tokens__
=
[]
self
.
tokenize
(
node
)
cell
=
self
.
__tokens__
[
1
:
-
1
].
copy
()
new_node
=
TableTree
(
node
.
tag
,
int
(
node
.
attrib
.
get
(
'colspan'
,
'1'
)),
int
(
node
.
attrib
.
get
(
'rowspan'
,
'1'
)),
cell
,
*
deque
())
else
:
new_node
=
TableTree
(
node
.
tag
,
None
,
None
,
None
,
*
deque
())
if
parent
is
not
None
:
parent
.
children
.
append
(
new_node
)
if
node
.
tag
!=
'td'
:
for
n
in
node
.
getchildren
():
self
.
load_html_tree
(
n
,
new_node
)
if
parent
is
None
:
return
new_node
def
evaluate
(
self
,
pred
,
true
):
''' Computes TEDS score between the prediction and the ground truth of a
given sample
'''
if
(
not
pred
)
or
(
not
true
):
return
0.0
parser
=
html
.
HTMLParser
(
remove_comments
=
True
,
encoding
=
'utf-8'
)
pred
=
html
.
fromstring
(
pred
,
parser
=
parser
)
true
=
html
.
fromstring
(
true
,
parser
=
parser
)
if
pred
.
xpath
(
'body/table'
)
and
true
.
xpath
(
'body/table'
):
pred
=
pred
.
xpath
(
'body/table'
)[
0
]
true
=
true
.
xpath
(
'body/table'
)[
0
]
if
self
.
ignore_nodes
:
etree
.
strip_tags
(
pred
,
*
self
.
ignore_nodes
)
etree
.
strip_tags
(
true
,
*
self
.
ignore_nodes
)
n_nodes_pred
=
len
(
pred
.
xpath
(
".//*"
))
n_nodes_true
=
len
(
true
.
xpath
(
".//*"
))
n_nodes
=
max
(
n_nodes_pred
,
n_nodes_true
)
tree_pred
=
self
.
load_html_tree
(
pred
)
tree_true
=
self
.
load_html_tree
(
true
)
distance
=
APTED
(
tree_pred
,
tree_true
,
CustomConfig
()).
compute_edit_distance
()
return
1.0
-
(
float
(
distance
)
/
n_nodes
)
else
:
return
0.0
def
batch_evaluate
(
self
,
pred_json
,
true_json
):
''' Computes TEDS score between the prediction and the ground truth of
a batch of samples
@params pred_json: {'FILENAME': 'HTML CODE', ...}
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
@output: {'FILENAME': 'TEDS SCORE', ...}
'''
samples
=
true_json
.
keys
()
if
self
.
n_jobs
==
1
:
scores
=
[
self
.
evaluate
(
pred_json
.
get
(
filename
,
''
),
true_json
[
filename
][
'html'
])
for
filename
in
tqdm
(
samples
)]
else
:
inputs
=
[{
'pred'
:
pred_json
.
get
(
filename
,
''
),
'true'
:
true_json
[
filename
][
'html'
]}
for
filename
in
samples
]
scores
=
parallel_process
(
inputs
,
self
.
evaluate
,
use_kwargs
=
True
,
n_jobs
=
self
.
n_jobs
,
front_num
=
1
)
scores
=
dict
(
zip
(
samples
,
scores
))
return
scores
def
batch_evaluate_html
(
self
,
pred_htmls
,
true_htmls
):
''' Computes TEDS score between the prediction and the ground truth of
a batch of samples
'''
if
self
.
n_jobs
==
1
:
scores
=
[
self
.
evaluate
(
pred_html
,
true_html
)
for
(
pred_html
,
true_html
)
in
zip
(
pred_htmls
,
true_htmls
)]
else
:
inputs
=
[{
"pred"
:
pred_html
,
"true"
:
true_html
}
for
(
pred_html
,
true_html
)
in
zip
(
pred_htmls
,
true_htmls
)]
scores
=
parallel_process
(
inputs
,
self
.
evaluate
,
use_kwargs
=
True
,
n_jobs
=
self
.
n_jobs
,
front_num
=
1
)
return
scores
if
__name__
==
'__main__'
:
import
json
import
pprint
with
open
(
'sample_pred.json'
)
as
fp
:
pred_json
=
json
.
load
(
fp
)
with
open
(
'sample_gt.json'
)
as
fp
:
true_json
=
json
.
load
(
fp
)
teds
=
TEDS
(
n_jobs
=
4
)
scores
=
teds
.
batch_evaluate
(
pred_json
,
true_json
)
pp
=
pprint
.
PrettyPrinter
()
pp
.
pprint
(
scores
)
ppstructure/table/tablepyxl/__init__.py
0 → 100644
View file @
037e17fc
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
\ No newline at end of file
ppstructure/table/tablepyxl/style.py
0 → 100644
View file @
037e17fc
# This is where we handle translating css styles into openpyxl styles
# and cascading those from parent to child in the dom.
from
openpyxl.cell
import
cell
from
openpyxl.styles
import
Font
,
Alignment
,
PatternFill
,
NamedStyle
,
Border
,
Side
,
Color
from
openpyxl.styles.fills
import
FILL_SOLID
from
openpyxl.styles.numbers
import
FORMAT_CURRENCY_USD_SIMPLE
,
FORMAT_PERCENTAGE
from
openpyxl.styles.colors
import
BLACK
FORMAT_DATE_MMDDYYYY
=
'mm/dd/yyyy'
def
colormap
(
color
):
"""
Convenience for looking up known colors
"""
cmap
=
{
'black'
:
BLACK
}
return
cmap
.
get
(
color
,
color
)
def
style_string_to_dict
(
style
):
"""
Convert css style string to a python dictionary
"""
def
clean_split
(
string
,
delim
):
return
(
s
.
strip
()
for
s
in
string
.
split
(
delim
))
styles
=
[
clean_split
(
s
,
":"
)
for
s
in
style
.
split
(
";"
)
if
":"
in
s
]
return
dict
(
styles
)
def
get_side
(
style
,
name
):
return
{
'border_style'
:
style
.
get
(
'border-{}-style'
.
format
(
name
)),
'color'
:
colormap
(
style
.
get
(
'border-{}-color'
.
format
(
name
)))}
known_styles
=
{}
def
style_dict_to_named_style
(
style_dict
,
number_format
=
None
):
"""
Change css style (stored in a python dictionary) to openpyxl NamedStyle
"""
style_and_format_string
=
str
({
'style_dict'
:
style_dict
,
'parent'
:
style_dict
.
parent
,
'number_format'
:
number_format
,
})
if
style_and_format_string
not
in
known_styles
:
# Font
font
=
Font
(
bold
=
style_dict
.
get
(
'font-weight'
)
==
'bold'
,
color
=
style_dict
.
get_color
(
'color'
,
None
),
size
=
style_dict
.
get
(
'font-size'
))
# Alignment
alignment
=
Alignment
(
horizontal
=
style_dict
.
get
(
'text-align'
,
'general'
),
vertical
=
style_dict
.
get
(
'vertical-align'
),
wrap_text
=
style_dict
.
get
(
'white-space'
,
'nowrap'
)
==
'normal'
)
# Fill
bg_color
=
style_dict
.
get_color
(
'background-color'
)
fg_color
=
style_dict
.
get_color
(
'foreground-color'
,
Color
())
fill_type
=
style_dict
.
get
(
'fill-type'
)
if
bg_color
and
bg_color
!=
'transparent'
:
fill
=
PatternFill
(
fill_type
=
fill_type
or
FILL_SOLID
,
start_color
=
bg_color
,
end_color
=
fg_color
)
else
:
fill
=
PatternFill
()
# Border
border
=
Border
(
left
=
Side
(
**
get_side
(
style_dict
,
'left'
)),
right
=
Side
(
**
get_side
(
style_dict
,
'right'
)),
top
=
Side
(
**
get_side
(
style_dict
,
'top'
)),
bottom
=
Side
(
**
get_side
(
style_dict
,
'bottom'
)),
diagonal
=
Side
(
**
get_side
(
style_dict
,
'diagonal'
)),
diagonal_direction
=
None
,
outline
=
Side
(
**
get_side
(
style_dict
,
'outline'
)),
vertical
=
None
,
horizontal
=
None
)
name
=
'Style {}'
.
format
(
len
(
known_styles
)
+
1
)
pyxl_style
=
NamedStyle
(
name
=
name
,
font
=
font
,
fill
=
fill
,
alignment
=
alignment
,
border
=
border
,
number_format
=
number_format
)
known_styles
[
style_and_format_string
]
=
pyxl_style
return
known_styles
[
style_and_format_string
]
class
StyleDict
(
dict
):
"""
It's like a dictionary, but it looks for items in the parent dictionary
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
parent
=
kwargs
.
pop
(
'parent'
,
None
)
super
(
StyleDict
,
self
).
__init__
(
*
args
,
**
kwargs
)
def
__getitem__
(
self
,
item
):
if
item
in
self
:
return
super
(
StyleDict
,
self
).
__getitem__
(
item
)
elif
self
.
parent
:
return
self
.
parent
[
item
]
else
:
raise
KeyError
(
'{} not found'
.
format
(
item
))
def
__hash__
(
self
):
return
hash
(
tuple
([(
k
,
self
.
get
(
k
))
for
k
in
self
.
_keys
()]))
# Yielding the keys avoids creating unnecessary data structures
# and happily works with both python2 and python3 where the
# .keys() method is a dictionary_view in python3 and a list in python2.
def
_keys
(
self
):
yielded
=
set
()
for
k
in
self
.
keys
():
yielded
.
add
(
k
)
yield
k
if
self
.
parent
:
for
k
in
self
.
parent
.
_keys
():
if
k
not
in
yielded
:
yielded
.
add
(
k
)
yield
k
def
get
(
self
,
k
,
d
=
None
):
try
:
return
self
[
k
]
except
KeyError
:
return
d
def
get_color
(
self
,
k
,
d
=
None
):
"""
Strip leading # off colors if necessary
"""
color
=
self
.
get
(
k
,
d
)
if
hasattr
(
color
,
'startswith'
)
and
color
.
startswith
(
'#'
):
color
=
color
[
1
:]
if
len
(
color
)
==
3
:
# Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
color
=
''
.
join
(
2
*
c
for
c
in
color
)
return
color
class
Element
(
object
):
"""
Our base class for representing an html element along with a cascading style.
The element is created along with a parent so that the StyleDict that we store
can point to the parent's StyleDict.
"""
def
__init__
(
self
,
element
,
parent
=
None
):
self
.
element
=
element
self
.
number_format
=
None
parent_style
=
parent
.
style_dict
if
parent
else
None
self
.
style_dict
=
StyleDict
(
style_string_to_dict
(
element
.
get
(
'style'
,
''
)),
parent
=
parent_style
)
self
.
_style_cache
=
None
def
style
(
self
):
"""
Turn the css styles for this element into an openpyxl NamedStyle.
"""
if
not
self
.
_style_cache
:
self
.
_style_cache
=
style_dict_to_named_style
(
self
.
style_dict
,
number_format
=
self
.
number_format
)
return
self
.
_style_cache
def
get_dimension
(
self
,
dimension_key
):
"""
Extracts the dimension from the style dict of the Element and returns it as a float.
"""
dimension
=
self
.
style_dict
.
get
(
dimension_key
)
if
dimension
:
if
dimension
[
-
2
:]
in
[
'px'
,
'em'
,
'pt'
,
'in'
,
'cm'
]:
dimension
=
dimension
[:
-
2
]
dimension
=
float
(
dimension
)
return
dimension
class
Table
(
Element
):
"""
The concrete implementations of Elements are semantically named for the types of elements we are interested in.
This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
"""
def
__init__
(
self
,
table
):
"""
takes an html table object (from lxml)
"""
super
(
Table
,
self
).
__init__
(
table
)
table_head
=
table
.
find
(
'thead'
)
self
.
head
=
TableHead
(
table_head
,
parent
=
self
)
if
table_head
is
not
None
else
None
table_body
=
table
.
find
(
'tbody'
)
self
.
body
=
TableBody
(
table_body
if
table_body
is
not
None
else
table
,
parent
=
self
)
class
TableHead
(
Element
):
"""
This class maps to the `<th>` element of the html table.
"""
def
__init__
(
self
,
head
,
parent
=
None
):
super
(
TableHead
,
self
).
__init__
(
head
,
parent
=
parent
)
self
.
rows
=
[
TableRow
(
tr
,
parent
=
self
)
for
tr
in
head
.
findall
(
'tr'
)]
class
TableBody
(
Element
):
"""
This class maps to the `<tbody>` element of the html table.
"""
def
__init__
(
self
,
body
,
parent
=
None
):
super
(
TableBody
,
self
).
__init__
(
body
,
parent
=
parent
)
self
.
rows
=
[
TableRow
(
tr
,
parent
=
self
)
for
tr
in
body
.
findall
(
'tr'
)]
class
TableRow
(
Element
):
"""
This class maps to the `<tr>` element of the html table.
"""
def
__init__
(
self
,
tr
,
parent
=
None
):
super
(
TableRow
,
self
).
__init__
(
tr
,
parent
=
parent
)
self
.
cells
=
[
TableCell
(
cell
,
parent
=
self
)
for
cell
in
tr
.
findall
(
'th'
)
+
tr
.
findall
(
'td'
)]
def
element_to_string
(
el
):
return
_element_to_string
(
el
).
strip
()
def
_element_to_string
(
el
):
string
=
''
for
x
in
el
.
iterchildren
():
string
+=
'
\n
'
+
_element_to_string
(
x
)
text
=
el
.
text
.
strip
()
if
el
.
text
else
''
tail
=
el
.
tail
.
strip
()
if
el
.
tail
else
''
return
text
+
string
+
'
\n
'
+
tail
class
TableCell
(
Element
):
"""
This class maps to the `<td>` element of the html table.
"""
CELL_TYPES
=
{
'TYPE_STRING'
,
'TYPE_FORMULA'
,
'TYPE_NUMERIC'
,
'TYPE_BOOL'
,
'TYPE_CURRENCY'
,
'TYPE_PERCENTAGE'
,
'TYPE_NULL'
,
'TYPE_INLINE'
,
'TYPE_ERROR'
,
'TYPE_FORMULA_CACHE_STRING'
,
'TYPE_INTEGER'
}
def
__init__
(
self
,
cell
,
parent
=
None
):
super
(
TableCell
,
self
).
__init__
(
cell
,
parent
=
parent
)
self
.
value
=
element_to_string
(
cell
)
self
.
number_format
=
self
.
get_number_format
()
def
data_type
(
self
):
cell_types
=
self
.
CELL_TYPES
&
set
(
self
.
element
.
get
(
'class'
,
''
).
split
())
if
cell_types
:
if
'TYPE_FORMULA'
in
cell_types
:
# Make sure TYPE_FORMULA takes precedence over the other classes in the set.
cell_type
=
'TYPE_FORMULA'
elif
cell_types
&
{
'TYPE_CURRENCY'
,
'TYPE_INTEGER'
,
'TYPE_PERCENTAGE'
}:
cell_type
=
'TYPE_NUMERIC'
else
:
cell_type
=
cell_types
.
pop
()
else
:
cell_type
=
'TYPE_STRING'
return
getattr
(
cell
,
cell_type
)
def
get_number_format
(
self
):
if
'TYPE_CURRENCY'
in
self
.
element
.
get
(
'class'
,
''
).
split
():
return
FORMAT_CURRENCY_USD_SIMPLE
if
'TYPE_INTEGER'
in
self
.
element
.
get
(
'class'
,
''
).
split
():
return
'#,##0'
if
'TYPE_PERCENTAGE'
in
self
.
element
.
get
(
'class'
,
''
).
split
():
return
FORMAT_PERCENTAGE
if
'TYPE_DATE'
in
self
.
element
.
get
(
'class'
,
''
).
split
():
return
FORMAT_DATE_MMDDYYYY
if
self
.
data_type
()
==
cell
.
TYPE_NUMERIC
:
try
:
int
(
self
.
value
)
except
ValueError
:
return
'#,##0.##'
else
:
return
'#,##0'
def
format
(
self
,
cell
):
cell
.
style
=
self
.
style
()
data_type
=
self
.
data_type
()
if
data_type
:
cell
.
data_type
=
data_type
\ No newline at end of file
ppstructure/table/tablepyxl/tablepyxl.py
0 → 100644
View file @
037e17fc
# Do imports like python3 so our package works for 2 and 3
from
__future__
import
absolute_import
from
lxml
import
html
from
openpyxl
import
Workbook
from
openpyxl.utils
import
get_column_letter
from
premailer
import
Premailer
from
tablepyxl.style
import
Table
def
string_to_int
(
s
):
if
s
.
isdigit
():
return
int
(
s
)
return
0
def
get_Tables
(
doc
):
tree
=
html
.
fromstring
(
doc
)
comments
=
tree
.
xpath
(
'//comment()'
)
for
comment
in
comments
:
comment
.
drop_tag
()
return
[
Table
(
table
)
for
table
in
tree
.
xpath
(
'//table'
)]
def
write_rows
(
worksheet
,
elem
,
row
,
column
=
1
):
"""
Writes every tr child element of elem to a row in the worksheet
returns the next row after all rows are written
"""
from
openpyxl.cell.cell
import
MergedCell
initial_column
=
column
for
table_row
in
elem
.
rows
:
for
table_cell
in
table_row
.
cells
:
cell
=
worksheet
.
cell
(
row
=
row
,
column
=
column
)
while
isinstance
(
cell
,
MergedCell
):
column
+=
1
cell
=
worksheet
.
cell
(
row
=
row
,
column
=
column
)
colspan
=
string_to_int
(
table_cell
.
element
.
get
(
"colspan"
,
"1"
))
rowspan
=
string_to_int
(
table_cell
.
element
.
get
(
"rowspan"
,
"1"
))
if
rowspan
>
1
or
colspan
>
1
:
worksheet
.
merge_cells
(
start_row
=
row
,
start_column
=
column
,
end_row
=
row
+
rowspan
-
1
,
end_column
=
column
+
colspan
-
1
)
cell
.
value
=
table_cell
.
value
table_cell
.
format
(
cell
)
min_width
=
table_cell
.
get_dimension
(
'min-width'
)
max_width
=
table_cell
.
get_dimension
(
'max-width'
)
if
colspan
==
1
:
# Initially, when iterating for the first time through the loop, the width of all the cells is None.
# As we start filling in contents, the initial width of the cell (which can be retrieved by:
# worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
# cell in the same column (i.e. width of A2 = width of A1)
width
=
max
(
worksheet
.
column_dimensions
[
get_column_letter
(
column
)].
width
or
0
,
len
(
table_cell
.
value
)
+
2
)
if
max_width
and
width
>
max_width
:
width
=
max_width
elif
min_width
and
width
<
min_width
:
width
=
min_width
worksheet
.
column_dimensions
[
get_column_letter
(
column
)].
width
=
width
column
+=
colspan
row
+=
1
column
=
initial_column
return
row
def
table_to_sheet
(
table
,
wb
):
"""
Takes a table and workbook and writes the table to a new sheet.
The sheet title will be the same as the table attribute name.
"""
ws
=
wb
.
create_sheet
(
title
=
table
.
element
.
get
(
'name'
))
insert_table
(
table
,
ws
,
1
,
1
)
def
document_to_workbook
(
doc
,
wb
=
None
,
base_url
=
None
):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document.
The workbook is returned
"""
if
not
wb
:
wb
=
Workbook
()
wb
.
remove
(
wb
.
active
)
inline_styles_doc
=
Premailer
(
doc
,
base_url
=
base_url
,
remove_classes
=
False
).
transform
()
tables
=
get_Tables
(
inline_styles_doc
)
for
table
in
tables
:
table_to_sheet
(
table
,
wb
)
return
wb
def
document_to_xl
(
doc
,
filename
,
base_url
=
None
):
"""
Takes a string representation of an html document and writes one sheet for
every table in the document. The workbook is written out to a file called filename
"""
wb
=
document_to_workbook
(
doc
,
base_url
=
base_url
)
wb
.
save
(
filename
)
def
insert_table
(
table
,
worksheet
,
column
,
row
):
if
table
.
head
:
row
=
write_rows
(
worksheet
,
table
.
head
,
row
,
column
)
if
table
.
body
:
row
=
write_rows
(
worksheet
,
table
.
body
,
row
,
column
)
def
insert_table_at_cell
(
table
,
cell
):
"""
Inserts a table at the location of an openpyxl Cell object.
"""
ws
=
cell
.
parent
column
,
row
=
cell
.
column
,
cell
.
row
insert_table
(
table
,
ws
,
column
,
row
)
\ No newline at end of file
ppstructure/utility.py
0 → 100644
View file @
037e17fc
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
PIL
import
Image
import
numpy
as
np
from
tools.infer.utility
import
draw_ocr_box_txt
,
init_args
as
infer_args
def
init_args
():
parser
=
infer_args
()
# params for output
parser
.
add_argument
(
"--output"
,
type
=
str
,
default
=
'./output/table'
)
# params for table structure
parser
.
add_argument
(
"--structure_max_len"
,
type
=
int
,
default
=
488
)
parser
.
add_argument
(
"--structure_max_text_length"
,
type
=
int
,
default
=
100
)
parser
.
add_argument
(
"--structure_max_elem_length"
,
type
=
int
,
default
=
800
)
parser
.
add_argument
(
"--structure_max_cell_num"
,
type
=
int
,
default
=
500
)
parser
.
add_argument
(
"--structure_model_dir"
,
type
=
str
)
parser
.
add_argument
(
"--structure_char_type"
,
type
=
str
,
default
=
'en'
)
parser
.
add_argument
(
"--structure_char_dict_path"
,
type
=
str
,
default
=
"../ppocr/utils/dict/table_structure_dict.txt"
)
# params for layout detector
parser
.
add_argument
(
"--layout_model_dir"
,
type
=
str
)
return
parser
def
parse_args
():
parser
=
init_args
()
return
parser
.
parse_args
()
def
draw_result
(
image
,
result
,
font_path
):
if
isinstance
(
image
,
np
.
ndarray
):
image
=
Image
.
fromarray
(
image
)
boxes
,
txts
,
scores
=
[],
[],
[]
for
region
in
result
:
if
region
[
'type'
]
==
'Table'
:
pass
elif
region
[
'type'
]
==
'Figure'
:
pass
else
:
for
box
,
rec_res
in
zip
(
region
[
'res'
][
0
],
region
[
'res'
][
1
]):
boxes
.
append
(
np
.
array
(
box
).
reshape
(
-
1
,
2
))
txts
.
append
(
rec_res
[
0
])
scores
.
append
(
rec_res
[
1
])
im_show
=
draw_ocr_box_txt
(
image
,
boxes
,
txts
,
scores
,
font_path
=
font_path
,
drop_score
=
0
)
return
im_show
\ No newline at end of file
tools/infer/predict_det.py
View file @
037e17fc
...
...
@@ -43,7 +43,7 @@ class TextDetector(object):
pre_process_list
=
[{
'DetResizeForTest'
:
{
'limit_side_len'
:
args
.
det_limit_side_len
,
'limit_type'
:
args
.
det_limit_type
'limit_type'
:
args
.
det_limit_type
,
}
},
{
'NormalizeImage'
:
{
...
...
tools/infer/utility.py
View file @
037e17fc
...
...
@@ -109,11 +109,12 @@ def init_args():
parser
.
add_argument
(
"--use_mp"
,
type
=
str2bool
,
default
=
False
)
parser
.
add_argument
(
"--total_process_num"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--process_id"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--benchmark"
,
type
=
bool
,
default
=
False
)
parser
.
add_argument
(
"--save_log_path"
,
type
=
str
,
default
=
"./log_output/"
)
parser
.
add_argument
(
"--show_log"
,
type
=
str2bool
,
default
=
True
)
return
parser
...
...
@@ -199,6 +200,8 @@ def create_predictor(args, mode, logger):
model_dir
=
args
.
cls_model_dir
elif
mode
==
'rec'
:
model_dir
=
args
.
rec_model_dir
elif
mode
==
'structure'
:
model_dir
=
args
.
structure_model_dir
else
:
model_dir
=
args
.
e2e_model_dir
...
...
@@ -328,7 +331,9 @@ def create_predictor(args, mode, logger):
config
.
delete_pass
(
"conv_transpose_eltwiseadd_bn_fuse_pass"
)
config
.
switch_use_feed_fetch_ops
(
False
)
config
.
switch_ir_optim
(
True
)
if
mode
==
'structure'
:
config
.
switch_ir_optim
(
False
)
# create predictor
predictor
=
inference
.
create_predictor
(
config
)
input_names
=
predictor
.
get_input_names
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment