Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
190e2231
Unverified
Commit
190e2231
authored
Nov 21, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Nov 21, 2024
Browse files
Merge pull request #1049 from myhloli/dev
feat(ocr): improve text detection and OCR accuracy
parents
a703e527
e52bd023
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
36 deletions
+49
-36
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+29
-30
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
+10
-5
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
+10
-1
No files found.
magic_pdf/model/pdf_extract_kit.py
View file @
190e2231
...
@@ -30,6 +30,7 @@ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
...
@@ -30,6 +30,7 @@ from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
class
CustomPEKModel
:
class
CustomPEKModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
"""
"""
======== model init ========
======== model init ========
...
@@ -149,13 +150,12 @@ class CustomPEKModel:
...
@@ -149,13 +150,12 @@ class CustomPEKModel:
device
=
self
.
device
,
device
=
self
.
device
,
)
)
# 初始化ocr
# 初始化ocr
if
self
.
apply_ocr
:
self
.
ocr_model
=
atom_model_manager
.
get_atom_model
(
self
.
ocr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
OCR
,
atom_model_name
=
AtomicModel
.
OCR
,
ocr_show_log
=
show_log
,
ocr_show_log
=
show_log
,
det_db_box_thresh
=
0.3
,
det_db_box_thresh
=
0.3
,
lang
=
self
.
lang
lang
=
self
.
lang
,
)
)
# init table model
# init table model
if
self
.
apply_table
:
if
self
.
apply_table
:
table_model_dir
=
self
.
configs
[
'weights'
][
self
.
table_model_name
]
table_model_dir
=
self
.
configs
[
'weights'
][
self
.
table_model_name
]
...
@@ -208,30 +208,29 @@ class CustomPEKModel:
...
@@ -208,30 +208,29 @@ class CustomPEKModel:
)
)
# ocr识别
# ocr识别
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
if
self
.
apply_ocr
:
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
else
:
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
,
rec
=
False
)[
0
]
# Integration results
if
ocr_res
:
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
)
layout_res
.
extend
(
ocr_result_list
)
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
if
self
.
apply_ocr
:
if
self
.
apply_ocr
:
ocr_start
=
time
.
time
()
logger
.
info
(
f
"ocr time:
{
ocr_cost
}
"
)
# Process each area that requires OCR processing
else
:
for
res
in
ocr_res_list
:
logger
.
info
(
f
"det time:
{
ocr_cost
}
"
)
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
# Integration results
if
ocr_res
:
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
)
layout_res
.
extend
(
ocr_result_list
)
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
logger
.
info
(
f
'ocr time:
{
ocr_cost
}
'
)
# 表格识别 table recognition
# 表格识别 table recognition
if
self
.
apply_table
:
if
self
.
apply_table
:
...
...
magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py
View file @
190e2231
...
@@ -211,16 +211,21 @@ def get_ocr_result_list(ocr_res, useful_list):
...
@@ -211,16 +211,21 @@ def get_ocr_result_list(ocr_res, useful_list):
ocr_result_list
=
[]
ocr_result_list
=
[]
for
box_ocr_res
in
ocr_res
:
for
box_ocr_res
in
ocr_res
:
p1
,
p2
,
p3
,
p4
=
box_ocr_res
[
0
]
if
len
(
box_ocr_res
)
==
2
:
text
,
score
=
box_ocr_res
[
1
]
p1
,
p2
,
p3
,
p4
=
box_ocr_res
[
0
]
text
,
score
=
box_ocr_res
[
1
]
else
:
p1
,
p2
,
p3
,
p4
=
box_ocr_res
text
,
score
=
""
,
1
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
# if average_angle_degrees > 0.5:
# if average_angle_degrees > 0.5:
if
calculate_is_angle
(
box_ocr_res
[
0
]):
poly
=
[
p1
,
p2
,
p3
,
p4
]
if
calculate_is_angle
(
poly
):
# logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
# logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
# 与x轴的夹角超过0.5度,对边界做一下矫正
# 与x轴的夹角超过0.5度,对边界做一下矫正
# 计算几何中心
# 计算几何中心
x_center
=
sum
(
point
[
0
]
for
point
in
box_ocr_res
[
0
]
)
/
4
x_center
=
sum
(
point
[
0
]
for
point
in
poly
)
/
4
y_center
=
sum
(
point
[
1
]
for
point
in
box_ocr_res
[
0
]
)
/
4
y_center
=
sum
(
point
[
1
]
for
point
in
poly
)
/
4
new_height
=
((
p4
[
1
]
-
p1
[
1
])
+
(
p3
[
1
]
-
p2
[
1
]))
/
2
new_height
=
((
p4
[
1
]
-
p1
[
1
])
+
(
p3
[
1
]
-
p2
[
1
]))
/
2
new_width
=
p3
[
0
]
-
p1
[
0
]
new_width
=
p3
[
0
]
-
p1
[
0
]
p1
=
[
x_center
-
new_width
/
2
,
y_center
-
new_height
/
2
]
p1
=
[
x_center
-
new_width
/
2
,
y_center
-
new_height
/
2
]
...
...
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py
View file @
190e2231
...
@@ -78,9 +78,18 @@ class ModifiedPaddleOCR(PaddleOCR):
...
@@ -78,9 +78,18 @@ class ModifiedPaddleOCR(PaddleOCR):
for
idx
,
img
in
enumerate
(
imgs
):
for
idx
,
img
in
enumerate
(
imgs
):
img
=
preprocess_image
(
img
)
img
=
preprocess_image
(
img
)
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
dt_boxes
,
elapse
=
self
.
text_detector
(
img
)
if
not
dt_boxes
:
if
dt_boxes
is
None
:
ocr_res
.
append
(
None
)
ocr_res
.
append
(
None
)
continue
continue
dt_boxes
=
sorted_boxes
(
dt_boxes
)
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
dt_boxes
=
merge_det_boxes
(
dt_boxes
)
if
mfd_res
:
bef
=
time
.
time
()
dt_boxes
=
update_det_boxes
(
dt_boxes
,
mfd_res
)
aft
=
time
.
time
()
logger
.
debug
(
"split text box by formula, new dt_boxes num : {}, elapsed : {}"
.
format
(
len
(
dt_boxes
),
aft
-
bef
))
tmp_res
=
[
box
.
tolist
()
for
box
in
dt_boxes
]
tmp_res
=
[
box
.
tolist
()
for
box
in
dt_boxes
]
ocr_res
.
append
(
tmp_res
)
ocr_res
.
append
(
tmp_res
)
return
ocr_res
return
ocr_res
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment