Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
2502db13
Unverified
Commit
2502db13
authored
Aug 09, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Aug 09, 2024
Browse files
Merge pull request #374 from myhloli/master
fix&refactor(pdf-extract-kit): table recognition and ocr
parents
ad5596fc
334ccac2
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
62 additions
and
67 deletions
+62
-67
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+62
-67
No files found.
magic_pdf/model/pdf_extract_kit.py
View file @
2502db13
...
@@ -27,7 +27,7 @@ except ImportError as e:
...
@@ -27,7 +27,7 @@ except ImportError as e:
logger
.
exception
(
e
)
logger
.
exception
(
e
)
logger
.
error
(
logger
.
error
(
'Required dependency not installed, please install by
\n
'
'Required dependency not installed, please install by
\n
'
'"pip install magic-pdf[full]
detectron2
--extra-index-url https://myhloli.github.io/wheels/"'
)
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"'
)
exit
(
1
)
exit
(
1
)
from
magic_pdf.model.pek_sub_modules.layoutlmv3.model_init
import
Layoutlmv3_Predictor
from
magic_pdf.model.pek_sub_modules.layoutlmv3.model_init
import
Layoutlmv3_Predictor
...
@@ -188,13 +188,9 @@ class CustomPEKModel:
...
@@ -188,13 +188,9 @@ class CustomPEKModel:
mfr_cost
=
round
(
time
.
time
()
-
mfr_start
,
2
)
mfr_cost
=
round
(
time
.
time
()
-
mfr_start
,
2
)
logger
.
info
(
f
"formula nums:
{
len
(
mf_image_list
)
}
, mfr time:
{
mfr_cost
}
"
)
logger
.
info
(
f
"formula nums:
{
len
(
mf_image_list
)
}
, mfr time:
{
mfr_cost
}
"
)
# ocr识别
# Select regions for OCR / formula regions / table regions
if
self
.
apply_ocr
:
ocr_start
=
time
.
time
()
pil_img
=
Image
.
fromarray
(
image
)
# 筛选出需要OCR的区域和公式区域
ocr_res_list
=
[]
ocr_res_list
=
[]
table_res_list
=
[]
single_page_mfdetrec_res
=
[]
single_page_mfdetrec_res
=
[]
for
res
in
layout_res
:
for
res
in
layout_res
:
if
int
(
res
[
'category_id'
])
in
[
13
,
14
]:
if
int
(
res
[
'category_id'
])
in
[
13
,
14
]:
...
@@ -204,34 +200,44 @@ class CustomPEKModel:
...
@@ -204,34 +200,44 @@ class CustomPEKModel:
})
})
elif
int
(
res
[
'category_id'
])
in
[
0
,
1
,
2
,
4
,
6
,
7
]:
elif
int
(
res
[
'category_id'
])
in
[
0
,
1
,
2
,
4
,
6
,
7
]:
ocr_res_list
.
append
(
res
)
ocr_res_list
.
append
(
res
)
elif
int
(
res
[
'category_id'
])
in
[
5
]:
table_res_list
.
append
(
res
)
# Unified crop img logic
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
    """Cut one layout region out of a page image onto a white canvas.

    The region is taken from ``input_res['poly']``: entries 0 and 1 are
    used as the left/top corner and entries 4 and 5 as the right/bottom
    corner, each truncated with ``int``.
    NOTE(review): presumably ``poly`` is a flat 4-point quadrilateral
    (x0, y0, ..., x3, y3) -- confirm against the layout model output.

    Args:
        input_res: Mapping with a ``'poly'`` sequence describing the region.
        input_pil_img: Image object exposing ``crop``; the page to cut from.
        crop_paste_x: Horizontal white margin added on each side (default 0).
        crop_paste_y: Vertical white margin added on each side (default 0).

    Returns:
        A ``(image, info)`` pair, where ``image`` is the new white RGB
        canvas with the cropped region pasted at
        ``(crop_paste_x, crop_paste_y)`` and ``info`` is the list
        ``[paste_x, paste_y, xmin, ymin, xmax, ymax, width, height]``
        (``width``/``height`` are the canvas dimensions).
    """
    poly = input_res['poly']
    left, top = int(poly[0]), int(poly[1])
    right, bottom = int(poly[4]), int(poly[5])

    # The canvas is the region size plus the margin on both sides of each axis.
    canvas_width = right - left + crop_paste_x * 2
    canvas_height = bottom - top + crop_paste_y * 2
    canvas = Image.new('RGB', (canvas_width, canvas_height), 'white')

    # Copy the region from the page and drop it onto the canvas at the margin offset.
    region = input_pil_img.crop((left, top, right, bottom))
    canvas.paste(region, (crop_paste_x, crop_paste_y))

    return canvas, [
        crop_paste_x,
        crop_paste_y,
        left,
        top,
        right,
        bottom,
        canvas_width,
        canvas_height,
    ]
pil_img
=
Image
.
fromarray
(
image
)
# 对每一个需OCR处理的区域进行处理
# ocr识别
if
self
.
apply_ocr
:
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
for
res
in
ocr_res_list
:
xmin
,
ymin
=
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
])
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
xmax
,
ymax
=
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])
paste_x
,
paste_y
,
xmin
,
ymin
,
xmax
,
ymax
,
new_width
,
new_height
=
useful_list
# Adjust the coordinates of the formula area
paste_x
=
50
paste_y
=
50
# 创建一个宽高各多50的白色背景
new_width
=
xmax
-
xmin
+
paste_x
*
2
new_height
=
ymax
-
ymin
+
paste_y
*
2
new_image
=
Image
.
new
(
'RGB'
,
(
new_width
,
new_height
),
'white'
)
# 裁剪图像
crop_box
=
(
xmin
,
ymin
,
xmax
,
ymax
)
cropped_img
=
pil_img
.
crop
(
crop_box
)
new_image
.
paste
(
cropped_img
,
(
paste_x
,
paste_y
))
# 调整公式区域坐标
adjusted_mfdetrec_res
=
[]
adjusted_mfdetrec_res
=
[]
for
mf_res
in
single_page_mfdetrec_res
:
for
mf_res
in
single_page_mfdetrec_res
:
mf_xmin
,
mf_ymin
,
mf_xmax
,
mf_ymax
=
mf_res
[
"bbox"
]
mf_xmin
,
mf_ymin
,
mf_xmax
,
mf_ymax
=
mf_res
[
"bbox"
]
#
将公式区域坐标调整为相对于裁剪区域的坐标
#
Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0
=
mf_xmin
-
xmin
+
paste_x
x0
=
mf_xmin
-
xmin
+
paste_x
y0
=
mf_ymin
-
ymin
+
paste_y
y0
=
mf_ymin
-
ymin
+
paste_y
x1
=
mf_xmax
-
xmin
+
paste_x
x1
=
mf_xmax
-
xmin
+
paste_x
y1
=
mf_ymax
-
ymin
+
paste_y
y1
=
mf_ymax
-
ymin
+
paste_y
#
过滤在图外的公式块
#
Filter formula blocks outside the graph
if
any
([
x1
<
0
,
y1
<
0
])
or
any
([
x0
>
new_width
,
y0
>
new_height
]):
if
any
([
x1
<
0
,
y1
<
0
])
or
any
([
x0
>
new_width
,
y0
>
new_height
]):
continue
continue
else
:
else
:
...
@@ -239,17 +245,17 @@ class CustomPEKModel:
...
@@ -239,17 +245,17 @@ class CustomPEKModel:
"bbox"
:
[
x0
,
y0
,
x1
,
y1
],
"bbox"
:
[
x0
,
y0
,
x1
,
y1
],
})
})
# OCR
识别
# OCR
recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
#
整合结果
#
Integration results
if
ocr_res
:
if
ocr_res
:
for
box_ocr_res
in
ocr_res
:
for
box_ocr_res
in
ocr_res
:
p1
,
p2
,
p3
,
p4
=
box_ocr_res
[
0
]
p1
,
p2
,
p3
,
p4
=
box_ocr_res
[
0
]
text
,
score
=
box_ocr_res
[
1
]
text
,
score
=
box_ocr_res
[
1
]
#
将坐标转换回原图坐标系
#
Convert the coordinates back to the original coordinate system
p1
=
[
p1
[
0
]
-
paste_x
+
xmin
,
p1
[
1
]
-
paste_y
+
ymin
]
p1
=
[
p1
[
0
]
-
paste_x
+
xmin
,
p1
[
1
]
-
paste_y
+
ymin
]
p2
=
[
p2
[
0
]
-
paste_x
+
xmin
,
p2
[
1
]
-
paste_y
+
ymin
]
p2
=
[
p2
[
0
]
-
paste_x
+
xmin
,
p2
[
1
]
-
paste_y
+
ymin
]
p3
=
[
p3
[
0
]
-
paste_x
+
xmin
,
p3
[
1
]
-
paste_y
+
ymin
]
p3
=
[
p3
[
0
]
-
paste_x
+
xmin
,
p3
[
1
]
-
paste_y
+
ymin
]
...
@@ -267,35 +273,24 @@ class CustomPEKModel:
...
@@ -267,35 +273,24 @@ class CustomPEKModel:
# 表格识别 table recognition
# 表格识别 table recognition
if
self
.
apply_table
:
if
self
.
apply_table
:
pil_img
=
Image
.
fromarray
(
image
)
table_start
=
time
.
time
()
for
layout
in
layout_res
:
for
res
in
table_res_list
:
if
layout
.
get
(
"category_id"
,
-
1
)
==
5
:
new_image
,
_
=
crop_img
(
res
,
pil_img
)
poly
=
layout
[
"poly"
]
single_table_start_time
=
time
.
time
()
xmin
,
ymin
=
int
(
poly
[
0
]),
int
(
poly
[
1
])
xmax
,
ymax
=
int
(
poly
[
4
]),
int
(
poly
[
5
])
paste_x
=
50
paste_y
=
50
# 创建一个宽高各多50的白色背景 create a whiteboard with 50 larger width and length
new_width
=
xmax
-
xmin
+
paste_x
*
2
new_height
=
ymax
-
ymin
+
paste_y
*
2
new_image
=
Image
.
new
(
'RGB'
,
(
new_width
,
new_height
),
'white'
)
# 裁剪图像 crop image
crop_box
=
(
xmin
,
ymin
,
xmax
,
ymax
)
cropped_img
=
pil_img
.
crop
(
crop_box
)
new_image
.
paste
(
cropped_img
,
(
paste_x
,
paste_y
))
start_time
=
time
.
time
()
logger
.
info
(
"------------------table recognition processing begins-----------------"
)
logger
.
info
(
"------------------table recognition processing begins-----------------"
)
with
torch
.
no_grad
():
latex_code
=
self
.
table_model
.
image2latex
(
new_image
)[
0
]
latex_code
=
self
.
table_model
.
image2latex
(
new_image
)[
0
]
end_time
=
time
.
time
()
run_time
=
time
.
time
()
-
single_table_start_time
run_time
=
end_time
-
start_time
logger
.
info
(
f
"------------table recognition processing ends within
{
run_time
}
s-----"
)
logger
.
info
(
f
"------------table recognition processing ends within
{
run_time
}
s-----"
)
if
run_time
>
self
.
table_max_time
:
if
run_time
>
self
.
table_max_time
:
logger
.
warning
(
f
"------------table recognition processing exceeds max time
{
self
.
table_max_time
}
s----------"
)
logger
.
warning
(
f
"------------table recognition processing exceeds max time
{
self
.
table_max_time
}
s----------"
)
# 判断是否返回正常
# 判断是否返回正常
if
latex_code
and
latex_code
.
strip
().
endswith
(
'end{tabular}'
):
expected_ending
=
latex_code
.
strip
().
endswith
(
'end{tabular}'
)
or
latex_code
.
strip
().
endswith
(
'end{table}'
)
layout
[
"latex"
]
=
latex_code
if
latex_code
and
expected_ending
:
res
[
"latex"
]
=
latex_code
else
:
else
:
logger
.
warning
(
f
"------------table recognition processing fails----------"
)
logger
.
warning
(
f
"------------table recognition processing fails----------"
)
table_cost
=
round
(
time
.
time
()
-
table_start
,
2
)
logger
.
info
(
f
"table cost:
{
table_cost
}
"
)
return
layout_res
return
layout_res
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment