Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
6d571e2e
Unverified
Commit
6d571e2e
authored
Oct 28, 2024
by
Kaiwen Liu
Committed by
GitHub
Oct 28, 2024
Browse files
Merge pull request #7 from opendatalab/dev
Dev
parents
a3358878
37c335ae
Changes
123
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1188 additions
and
345 deletions
+1188
-345
magic_pdf/libs/ocr_content_type.py
magic_pdf/libs/ocr_content_type.py
+2
-0
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+71
-31
magic_pdf/model/magic_model.py
magic_pdf/model/magic_model.py
+265
-22
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+86
-47
magic_pdf/model/ppTableModel.py
magic_pdf/model/ppTableModel.py
+2
-2
magic_pdf/para/para_split_v3.py
magic_pdf/para/para_split_v3.py
+296
-0
magic_pdf/pdf_parse_by_ocr.py
magic_pdf/pdf_parse_by_ocr.py
+5
-2
magic_pdf/pdf_parse_by_txt.py
magic_pdf/pdf_parse_by_txt.py
+5
-2
magic_pdf/pdf_parse_union_core_v2.py
magic_pdf/pdf_parse_union_core_v2.py
+311
-165
magic_pdf/pipe/AbsPipe.py
magic_pdf/pipe/AbsPipe.py
+6
-7
magic_pdf/pipe/OCRPipe.py
magic_pdf/pipe/OCRPipe.py
+8
-4
magic_pdf/pipe/TXTPipe.py
magic_pdf/pipe/TXTPipe.py
+8
-4
magic_pdf/pipe/UNIPipe.py
magic_pdf/pipe/UNIPipe.py
+10
-5
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
+56
-27
magic_pdf/pre_proc/ocr_dict_merge.py
magic_pdf/pre_proc/ocr_dict_merge.py
+27
-2
magic_pdf/resources/model_config/model_configs.yaml
magic_pdf/resources/model_config/model_configs.yaml
+5
-13
magic_pdf/tools/cli.py
magic_pdf/tools/cli.py
+1
-1
magic_pdf/tools/common.py
magic_pdf/tools/common.py
+11
-6
magic_pdf/user_api.py
magic_pdf/user_api.py
+13
-5
magic_pdf/utils/__init__.py
magic_pdf/utils/__init__.py
+0
-0
No files found.
magic_pdf/libs/ocr_content_type.py
View file @
6d571e2e
...
...
@@ -20,6 +20,8 @@ class BlockType:
InterlineEquation
=
'interline_equation'
Footnote
=
'footnote'
Discarded
=
'discarded'
List
=
'list'
Index
=
'index'
class
CategoryId
:
...
...
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
6d571e2e
...
...
@@ -4,7 +4,9 @@ import fitz
import
numpy
as
np
from
loguru
import
logger
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
,
get_table_recog_config
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
,
get_table_recog_config
,
get_layout_config
,
\
get_formula_config
from
magic_pdf.model.model_list
import
MODEL
import
magic_pdf.model
as
model_config
...
...
@@ -23,7 +25,7 @@ def remove_duplicates_dicts(lst):
return
unique_dicts
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
)
->
list
:
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
try
:
from
PIL
import
Image
except
ImportError
:
...
...
@@ -32,18 +34,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
images
=
[]
with
fitz
.
open
(
"pdf"
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
end_page_id
=
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
pdf_page_num
-
1
if
end_page_id
>
pdf_page_num
-
1
:
logger
.
warning
(
"end_page_id is out of range, use images length"
)
end_page_id
=
pdf_page_num
-
1
for
index
in
range
(
0
,
doc
.
page_count
):
page
=
doc
[
index
]
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
page
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
if
start_page_id
<=
index
<=
end_page_id
:
page
=
doc
[
index
]
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
page
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
# If the width or height exceeds 9000 after scaling, do not scale further.
if
pm
.
width
>
9000
or
pm
.
height
>
9000
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
# If the width or height exceeds 9000 after scaling, do not scale further.
if
pm
.
width
>
9000
or
pm
.
height
>
9000
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
"RGB"
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
img_dict
=
{
"img"
:
img
,
"width"
:
pm
.
width
,
"height"
:
pm
.
height
}
else
:
img_dict
=
{
"img"
:
[],
"width"
:
0
,
"height"
:
0
}
img
=
Image
.
frombytes
(
"RGB"
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
img_dict
=
{
"img"
:
img
,
"width"
:
pm
.
width
,
"height"
:
pm
.
height
}
images
.
append
(
img_dict
)
return
images
...
...
@@ -57,14 +69,17 @@ class ModelSingleton:
cls
.
_instance
=
super
().
__new__
(
cls
)
return
cls
.
_instance
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
,
lang
=
None
):
key
=
(
ocr
,
show_log
,
lang
)
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
key
=
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
)
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
return
self
.
_models
[
key
]
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
):
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
model
=
None
if
model_config
.
__model_mode__
==
"lite"
:
...
...
@@ -84,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
# 从配置文件读取model-dir和device
local_models_dir
=
get_local_models_dir
()
device
=
get_device
()
layout_config
=
get_layout_config
()
if
layout_model
is
not
None
:
layout_config
[
"model"
]
=
layout_model
formula_config
=
get_formula_config
()
if
formula_enable
is
not
None
:
formula_config
[
"enable"
]
=
formula_enable
table_config
=
get_table_recog_config
()
model_input
=
{
"ocr"
:
ocr
,
"show_log"
:
show_log
,
"models_dir"
:
local_models_dir
,
"device"
:
device
,
"table_config"
:
table_config
,
"lang"
:
lang
,
}
if
table_enable
is
not
None
:
table_config
[
"enable"
]
=
table_enable
model_input
=
{
"ocr"
:
ocr
,
"show_log"
:
show_log
,
"models_dir"
:
local_models_dir
,
"device"
:
device
,
"table_config"
:
table_config
,
"layout_config"
:
layout_config
,
"formula_config"
:
formula_config
,
"lang"
:
lang
,
}
custom_model
=
CustomPEKModel
(
**
model_input
)
else
:
logger
.
error
(
"Not allow model_name!"
)
...
...
@@ -106,19 +137,23 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def
doc_analyze
(
pdf_bytes
:
bytes
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
)
if
lang
==
""
:
lang
=
None
images
=
load_images_from_pdf
(
pdf_bytes
)
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
# end_page_id = end_page_id if end_page_id else len(images) - 1
end_page_id
=
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
len
(
images
)
-
1
with
fitz
.
open
(
"pdf"
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
end_page_id
=
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
pdf_page_num
-
1
if
end_page_id
>
pdf_page_num
-
1
:
logger
.
warning
(
"end_page_id is out of range, use images length"
)
end_page_id
=
pdf_page_num
-
1
if
end_page_id
>
len
(
images
)
-
1
:
logger
.
warning
(
"end_page_id is out of range, use images length"
)
end_page_id
=
len
(
images
)
-
1
images
=
load_images_from_pdf
(
pdf_bytes
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
)
model_json
=
[]
doc_analyze_start
=
time
.
time
()
...
...
@@ -135,6 +170,11 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_dict
=
{
"layout_dets"
:
result
,
"page_info"
:
page_info
}
model_json
.
append
(
page_dict
)
gc_start
=
time
.
time
()
clean_memory
()
gc_time
=
round
(
time
.
time
()
-
gc_start
,
2
)
logger
.
info
(
f
"gc time:
{
gc_time
}
"
)
doc_analyze_time
=
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
doc_analyze_speed
=
round
(
(
end_page_id
+
1
-
start_page_id
)
/
doc_analyze_time
,
2
)
logger
.
info
(
f
"doc analyze time:
{
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
}
,"
...
...
magic_pdf/model/magic_model.py
View file @
6d571e2e
import
json
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.boxbase
import
(
_is_in
,
_is_part_overlap
,
bbox_distance
,
bbox_relative_pos
,
box_area
,
calculate_iou
,
calculate_overlap_area_in_bbox1_area_ratio
,
...
...
@@ -9,6 +10,7 @@ from magic_pdf.libs.coordinate_transform import get_scale_ratio
from
magic_pdf.libs.local_math
import
float_gt
from
magic_pdf.libs.ModelBlockTypeEnum
import
ModelBlockTypeEnum
from
magic_pdf.libs.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.rw.DiskReaderWriter
import
DiskReaderWriter
...
...
@@ -24,7 +26,7 @@ class MagicModel:
need_remove_list
=
[]
page_no
=
model_page_info
[
'page_info'
][
'page_no'
]
horizontal_scale_ratio
,
vertical_scale_ratio
=
get_scale_ratio
(
model_page_info
,
self
.
__docs
[
page_no
]
model_page_info
,
self
.
__docs
.
get_page
(
page_no
)
)
layout_dets
=
model_page_info
[
'layout_dets'
]
for
layout_det
in
layout_dets
:
...
...
@@ -99,7 +101,7 @@ class MagicModel:
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__init__
(
self
,
model_list
:
list
,
docs
:
fitz
.
Documen
t
):
def
__init__
(
self
,
model_list
:
list
,
docs
:
Datase
t
):
self
.
__model_list
=
model_list
self
.
__docs
=
docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
...
...
@@ -119,15 +121,13 @@ class MagicModel:
if
left
or
right
:
l1
=
bbox1
[
3
]
-
bbox1
[
1
]
l2
=
bbox2
[
3
]
-
bbox2
[
1
]
minL
,
maxL
=
min
(
l1
,
l2
),
max
(
l1
,
l2
)
if
(
maxL
-
minL
)
/
minL
>
0.5
:
return
float
(
'inf'
)
if
bottom
or
top
:
else
:
l1
=
bbox1
[
2
]
-
bbox1
[
0
]
l2
=
bbox2
[
2
]
-
bbox2
[
0
]
minL
,
maxL
=
min
(
l1
,
l2
),
max
(
l1
,
l2
)
if
(
maxL
-
minL
)
/
minL
>
0.5
:
return
float
(
'inf'
)
if
l2
>
l1
and
(
l2
-
l1
)
/
l1
>
0.3
:
return
float
(
'inf'
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
__fix_footnote
(
self
):
...
...
@@ -215,9 +215,8 @@ class MagicModel:
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def
search_overlap_between_boxes
(
subject_idx
,
object_idx
):
def
search_overlap_between_boxes
(
subject_idx
,
object_idx
):
idxes
=
[
subject_idx
,
object_idx
]
x0s
=
[
all_bboxes
[
idx
][
'bbox'
][
0
]
for
idx
in
idxes
]
y0s
=
[
all_bboxes
[
idx
][
'bbox'
][
1
]
for
idx
in
idxes
]
...
...
@@ -245,9 +244,9 @@ class MagicModel:
for
other_object
in
other_objects
:
ratio
=
max
(
ratio
,
get_overlap_area
(
merged_bbox
,
other_object
[
'bbox'
]
)
*
1.0
/
box_area
(
all_bboxes
[
object_idx
][
'bbox'
])
get_overlap_area
(
merged_bbox
,
other_object
[
'bbox'
])
*
1.0
/
box_area
(
all_bboxes
[
object_idx
][
'bbox'
])
,
)
if
ratio
>=
MERGE_BOX_OVERLAP_AREA_RATIO
:
break
...
...
@@ -365,12 +364,17 @@ class MagicModel:
if
all_bboxes
[
j
][
'category_id'
]
==
subject_category_id
:
subject_idx
,
object_idx
=
j
,
i
if
search_overlap_between_boxes
(
subject_idx
,
object_idx
)
>=
MERGE_BOX_OVERLAP_AREA_RATIO
:
if
(
search_overlap_between_boxes
(
subject_idx
,
object_idx
)
>=
MERGE_BOX_OVERLAP_AREA_RATIO
):
dis
[
i
][
j
]
=
float
(
'inf'
)
dis
[
j
][
i
]
=
dis
[
i
][
j
]
continue
dis
[
i
][
j
]
=
self
.
_bbox_distance
(
all_bboxes
[
i
][
'bbox'
],
all_bboxes
[
j
][
'bbox'
])
dis
[
i
][
j
]
=
self
.
_bbox_distance
(
all_bboxes
[
subject_idx
][
'bbox'
],
all_bboxes
[
object_idx
][
'bbox'
]
)
dis
[
j
][
i
]
=
dis
[
i
][
j
]
used
=
set
()
...
...
@@ -461,7 +465,7 @@ class MagicModel:
if
is_nearest
:
nx0
,
ny0
,
nx1
,
ny1
=
expand_bbbox
(
list
(
seen
)
+
[
k
])
n_dis
=
self
.
_
bbox_distance
(
n_dis
=
bbox_distance
(
all_bboxes
[
i
][
'bbox'
],
[
nx0
,
ny0
,
nx1
,
ny1
]
)
if
float_gt
(
dis
[
i
][
j
],
n_dis
):
...
...
@@ -557,7 +561,7 @@ class MagicModel:
# 计算已经配对的 distance 距离
for
i
in
subject_object_relation_map
.
keys
():
for
j
in
subject_object_relation_map
[
i
]:
total_subject_object_dis
+=
self
.
_
bbox_distance
(
total_subject_object_dis
+=
bbox_distance
(
all_bboxes
[
i
][
'bbox'
],
all_bboxes
[
j
][
'bbox'
]
)
...
...
@@ -586,6 +590,245 @@ class MagicModel:
with_caption_subject
.
add
(
j
)
return
ret
,
total_subject_object_dis
def
__tie_up_category_by_distance_v2
(
self
,
page_no
,
subject_category_id
,
object_category_id
):
AXIS_MULPLICITY
=
0.5
subjects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
)
objects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
)
M
=
len
(
objects
)
subjects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
sub_obj_map_h
=
{
i
:
[]
for
i
in
range
(
len
(
subjects
))}
dis_by_directions
=
{
'top'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
'bottom'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
'left'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
'right'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
}
for
i
,
obj
in
enumerate
(
objects
):
l_x_axis
,
l_y_axis
=
(
obj
[
'bbox'
][
2
]
-
obj
[
'bbox'
][
0
],
obj
[
'bbox'
][
3
]
-
obj
[
'bbox'
][
1
],
)
axis_unit
=
min
(
l_x_axis
,
l_y_axis
)
for
j
,
sub
in
enumerate
(
subjects
):
bbox1
,
bbox2
,
_
=
_remove_overlap_between_bbox
(
objects
[
i
][
'bbox'
],
subjects
[
j
][
'bbox'
]
)
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
flags
=
[
left
,
right
,
bottom
,
top
]
if
sum
([
1
if
v
else
0
for
v
in
flags
])
>
1
:
continue
if
left
:
if
dis_by_directions
[
'left'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'left'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
right
:
if
dis_by_directions
[
'right'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'right'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
bottom
:
if
dis_by_directions
[
'bottom'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'bottom'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
top
:
if
dis_by_directions
[
'top'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'top'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
dis_by_directions
[
'left'
][
i
][
1
]
!=
float
(
'inf'
)
or
dis_by_directions
[
'right'
][
i
][
1
]
!=
float
(
'inf'
):
if
dis_by_directions
[
'left'
][
i
][
1
]
!=
float
(
'inf'
)
and
dis_by_directions
[
'right'
][
i
][
1
]
!=
float
(
'inf'
):
if
AXIS_MULPLICITY
*
axis_unit
>=
abs
(
dis_by_directions
[
'left'
][
i
][
1
]
-
dis_by_directions
[
'right'
][
i
][
1
]
):
left_sub_bbox
=
subjects
[
dis_by_directions
[
'left'
][
i
][
0
]][
'bbox'
]
right_sub_bbox
=
subjects
[
dis_by_directions
[
'right'
][
i
][
0
]][
'bbox'
]
left_sub_bbox_y_axis
=
left_sub_bbox
[
3
]
-
left_sub_bbox
[
1
]
right_sub_bbox_y_axis
=
right_sub_bbox
[
3
]
-
right_sub_bbox
[
1
]
if
(
abs
(
left_sub_bbox_y_axis
-
l_y_axis
)
+
dis_by_directions
[
'left'
][
i
][
0
]
>
abs
(
right_sub_bbox_y_axis
-
l_y_axis
)
+
dis_by_directions
[
'right'
][
i
][
0
]
):
left_or_right
=
dis_by_directions
[
'right'
][
i
]
else
:
left_or_right
=
dis_by_directions
[
'left'
][
i
]
else
:
left_or_right
=
dis_by_directions
[
'left'
][
i
]
if
left_or_right
[
1
]
>
dis_by_directions
[
'right'
][
i
][
1
]:
left_or_right
=
dis_by_directions
[
'right'
][
i
]
else
:
left_or_right
=
dis_by_directions
[
'left'
][
i
]
if
left_or_right
[
1
]
==
float
(
'inf'
):
left_or_right
=
dis_by_directions
[
'right'
][
i
]
else
:
left_or_right
=
[
-
1
,
float
(
'inf'
)]
if
dis_by_directions
[
'top'
][
i
][
1
]
!=
float
(
'inf'
)
or
dis_by_directions
[
'bottom'
][
i
][
1
]
!=
float
(
'inf'
):
if
dis_by_directions
[
'top'
][
i
][
1
]
!=
float
(
'inf'
)
and
dis_by_directions
[
'bottom'
][
i
][
1
]
!=
float
(
'inf'
):
if
AXIS_MULPLICITY
*
axis_unit
>=
abs
(
dis_by_directions
[
'top'
][
i
][
1
]
-
dis_by_directions
[
'bottom'
][
i
][
1
]
):
top_bottom
=
subjects
[
dis_by_directions
[
'bottom'
][
i
][
0
]][
'bbox'
]
bottom_top
=
subjects
[
dis_by_directions
[
'top'
][
i
][
0
]][
'bbox'
]
top_bottom_x_axis
=
top_bottom
[
2
]
-
top_bottom
[
0
]
bottom_top_x_axis
=
bottom_top
[
2
]
-
bottom_top
[
0
]
if
abs
(
top_bottom_x_axis
-
l_x_axis
)
+
dis_by_directions
[
'bottom'
][
i
][
1
]
>
abs
(
bottom_top_x_axis
-
l_x_axis
)
+
dis_by_directions
[
'top'
][
i
][
1
]:
top_or_bottom
=
dis_by_directions
[
'top'
][
i
]
else
:
top_or_bottom
=
dis_by_directions
[
'bottom'
][
i
]
else
:
top_or_bottom
=
dis_by_directions
[
'top'
][
i
]
if
top_or_bottom
[
1
]
>
dis_by_directions
[
'bottom'
][
i
][
1
]:
top_or_bottom
=
dis_by_directions
[
'bottom'
][
i
]
else
:
top_or_bottom
=
dis_by_directions
[
'top'
][
i
]
if
top_or_bottom
[
1
]
==
float
(
'inf'
):
top_or_bottom
=
dis_by_directions
[
'bottom'
][
i
]
else
:
top_or_bottom
=
[
-
1
,
float
(
'inf'
)]
if
left_or_right
[
1
]
!=
float
(
'inf'
)
or
top_or_bottom
[
1
]
!=
float
(
'inf'
):
if
left_or_right
[
1
]
!=
float
(
'inf'
)
and
top_or_bottom
[
1
]
!=
float
(
'inf'
):
if
AXIS_MULPLICITY
*
axis_unit
>=
abs
(
left_or_right
[
1
]
-
top_or_bottom
[
1
]
):
y_axis_bbox
=
subjects
[
left_or_right
[
0
]][
'bbox'
]
x_axis_bbox
=
subjects
[
top_or_bottom
[
0
]][
'bbox'
]
if
(
abs
((
x_axis_bbox
[
2
]
-
x_axis_bbox
[
0
])
-
l_x_axis
)
/
l_x_axis
>
abs
((
y_axis_bbox
[
3
]
-
y_axis_bbox
[
1
])
-
l_y_axis
)
/
l_y_axis
):
sub_obj_map_h
[
left_or_right
[
0
]].
append
(
i
)
else
:
sub_obj_map_h
[
top_or_bottom
[
0
]].
append
(
i
)
else
:
if
left_or_right
[
1
]
>
top_or_bottom
[
1
]:
sub_obj_map_h
[
top_or_bottom
[
0
]].
append
(
i
)
else
:
sub_obj_map_h
[
left_or_right
[
0
]].
append
(
i
)
else
:
if
left_or_right
[
1
]
!=
float
(
'inf'
):
sub_obj_map_h
[
left_or_right
[
0
]].
append
(
i
)
else
:
sub_obj_map_h
[
top_or_bottom
[
0
]].
append
(
i
)
ret
=
[]
for
i
in
sub_obj_map_h
.
keys
():
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
i
][
'bbox'
],
'score'
:
subjects
[
i
][
'score'
],
},
'obj_bboxes'
:
[
{
'score'
:
objects
[
j
][
'score'
],
'bbox'
:
objects
[
j
][
'bbox'
]}
for
j
in
sub_obj_map_h
[
i
]
],
'sub_idx'
:
i
,
}
)
return
ret
def
get_imgs_v2
(
self
,
page_no
:
int
):
with_captions
=
self
.
__tie_up_category_by_distance_v2
(
page_no
,
3
,
4
)
with_footnotes
=
self
.
__tie_up_category_by_distance_v2
(
page_no
,
3
,
CategoryId
.
ImageFootnote
)
ret
=
[]
for
v
in
with_captions
:
record
=
{
'image_body'
:
v
[
'sub_bbox'
],
'image_caption_list'
:
v
[
'obj_bboxes'
],
}
filter_idx
=
v
[
'sub_idx'
]
d
=
next
(
filter
(
lambda
x
:
x
[
'sub_idx'
]
==
filter_idx
,
with_footnotes
))
record
[
'image_footnote_list'
]
=
d
[
'obj_bboxes'
]
ret
.
append
(
record
)
return
ret
def
get_tables_v2
(
self
,
page_no
:
int
)
->
list
:
with_captions
=
self
.
__tie_up_category_by_distance_v2
(
page_no
,
5
,
6
)
with_footnotes
=
self
.
__tie_up_category_by_distance_v2
(
page_no
,
5
,
7
)
ret
=
[]
for
v
in
with_captions
:
record
=
{
'table_body'
:
v
[
'sub_bbox'
],
'table_caption_list'
:
v
[
'obj_bboxes'
],
}
filter_idx
=
v
[
'sub_idx'
]
d
=
next
(
filter
(
lambda
x
:
x
[
'sub_idx'
]
==
filter_idx
,
with_footnotes
))
record
[
'table_footnote_list'
]
=
d
[
'obj_bboxes'
]
ret
.
append
(
record
)
return
ret
def
get_imgs
(
self
,
page_no
:
int
):
with_captions
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
3
,
4
)
with_footnotes
,
_
=
self
.
__tie_up_category_by_distance
(
...
...
@@ -719,10 +962,10 @@ class MagicModel:
def
get_page_size
(
self
,
page_no
:
int
):
# 获取页面宽高
# 获取当前页的page对象
page
=
self
.
__docs
[
page_no
]
page
=
self
.
__docs
.
get_page
(
page_no
).
get_page_info
()
# 获取当前页的宽高
page_w
=
page
.
rect
.
width
page_h
=
page
.
rect
.
height
page_w
=
page
.
w
page_h
=
page
.
h
return
page_w
,
page_h
def
__get_blocks_by_type
(
...
...
magic_pdf/model/pdf_extract_kit.py
View file @
6d571e2e
...
...
@@ -26,6 +26,7 @@ try:
from
unimernet.common.config
import
Config
import
unimernet.tasks
as
tasks
from
unimernet.processors
import
load_processor
from
doclayout_yolo
import
YOLOv10
except
ImportError
as
e
:
logger
.
exception
(
e
)
...
...
@@ -42,7 +43,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
):
if
table_model_type
==
STRUCT_EQTABLE
:
if
table_model_type
==
MODEL_NAME
.
STRUCT_EQTABLE
:
table_model
=
StructTableModel
(
model_path
,
max_time
=
max_time
,
device
=
_device_
)
else
:
config
=
{
...
...
@@ -83,11 +84,16 @@ def layout_model_init(weight, config_file, device):
return
model
def
ocr_model_init
(
show_log
:
bool
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
None
):
def
doclayout_yolo_model_init
(
weight
):
model
=
YOLOv10
(
weight
)
return
model
def
ocr_model_init
(
show_log
:
bool
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
None
,
use_dilation
=
True
,
det_db_unclip_ratio
=
1.8
):
if
lang
is
not
None
:
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
lang
=
lang
)
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
lang
=
lang
,
use_dilation
=
use_dilation
,
det_db_unclip_ratio
=
det_db_unclip_ratio
)
else
:
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
)
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
use_dilation
=
use_dilation
,
det_db_unclip_ratio
=
det_db_unclip_ratio
)
return
model
...
...
@@ -120,19 +126,27 @@ class AtomModelSingleton:
return
cls
.
_instance
def
get_atom_model
(
self
,
atom_model_name
:
str
,
**
kwargs
):
if
atom_model_name
not
in
self
.
_models
:
self
.
_models
[
atom_model_name
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
return
self
.
_models
[
atom_model_name
]
lang
=
kwargs
.
get
(
"lang"
,
None
)
layout_model_name
=
kwargs
.
get
(
"layout_model_name"
,
None
)
key
=
(
atom_model_name
,
layout_model_name
,
lang
)
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
return
self
.
_models
[
key
]
def
atom_model_init
(
model_name
:
str
,
**
kwargs
):
if
model_name
==
AtomicModel
.
Layout
:
atom_model
=
layout_model_init
(
kwargs
.
get
(
"layout_weights"
),
kwargs
.
get
(
"layout_config_file"
),
kwargs
.
get
(
"device"
)
)
if
kwargs
.
get
(
"layout_model_name"
)
==
MODEL_NAME
.
LAYOUTLMv3
:
atom_model
=
layout_model_init
(
kwargs
.
get
(
"layout_weights"
),
kwargs
.
get
(
"layout_config_file"
),
kwargs
.
get
(
"device"
)
)
elif
kwargs
.
get
(
"layout_model_name"
)
==
MODEL_NAME
.
DocLayout_YOLO
:
atom_model
=
doclayout_yolo_model_init
(
kwargs
.
get
(
"doclayout_yolo_weights"
),
)
elif
model_name
==
AtomicModel
.
MFD
:
atom_model
=
mfd_model_init
(
kwargs
.
get
(
"mfd_weights"
)
...
...
@@ -151,7 +165,7 @@ def atom_model_init(model_name: str, **kwargs):
)
elif
model_name
==
AtomicModel
.
Table
:
atom_model
=
table_model_init
(
kwargs
.
get
(
"table_model_
typ
e"
),
kwargs
.
get
(
"table_model_
nam
e"
),
kwargs
.
get
(
"table_model_path"
),
kwargs
.
get
(
"table_max_time"
),
kwargs
.
get
(
"device"
)
...
...
@@ -199,23 +213,35 @@ class CustomPEKModel:
with
open
(
config_path
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
self
.
configs
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
# 初始化解析配置
self
.
apply_layout
=
kwargs
.
get
(
"apply_layout"
,
self
.
configs
[
"config"
][
"layout"
])
self
.
apply_formula
=
kwargs
.
get
(
"apply_formula"
,
self
.
configs
[
"config"
][
"formula"
])
# layout config
self
.
layout_config
=
kwargs
.
get
(
"layout_config"
)
self
.
layout_model_name
=
self
.
layout_config
.
get
(
"model"
,
MODEL_NAME
.
DocLayout_YOLO
)
# formula config
self
.
formula_config
=
kwargs
.
get
(
"formula_config"
)
self
.
mfd_model_name
=
self
.
formula_config
.
get
(
"mfd_model"
,
MODEL_NAME
.
YOLO_V8_MFD
)
self
.
mfr_model_name
=
self
.
formula_config
.
get
(
"mfr_model"
,
MODEL_NAME
.
UniMerNet_v2_Small
)
self
.
apply_formula
=
self
.
formula_config
.
get
(
"enable"
,
True
)
# table config
self
.
table_config
=
kwargs
.
get
(
"table_config"
,
self
.
configs
[
"config"
][
"table_config"
]
)
self
.
apply_table
=
self
.
table_config
.
get
(
"
is_table_recog_
enable"
,
False
)
self
.
table_config
=
kwargs
.
get
(
"table_config"
)
self
.
apply_table
=
self
.
table_config
.
get
(
"enable"
,
False
)
self
.
table_max_time
=
self
.
table_config
.
get
(
"max_time"
,
TABLE_MAX_TIME_VALUE
)
self
.
table_model_type
=
self
.
table_config
.
get
(
"model"
,
TABLE_MASTER
)
self
.
table_model_name
=
self
.
table_config
.
get
(
"model"
,
MODEL_NAME
.
TABLE_MASTER
)
# ocr config
self
.
apply_ocr
=
ocr
self
.
lang
=
kwargs
.
get
(
"lang"
,
None
)
logger
.
info
(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}"
.
format
(
self
.
apply_layout
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
lang
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}"
.
format
(
self
.
layout_model_name
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
table_model_name
,
self
.
lang
)
)
assert
self
.
apply_layout
,
"DocAnalysis must contain layout model."
# 初始化解析方案
self
.
device
=
kwargs
.
get
(
"device"
,
self
.
configs
[
"config"
][
"device"
]
)
self
.
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
logger
.
info
(
"using models_dir: {}"
.
format
(
models_dir
))
...
...
@@ -224,17 +250,16 @@ class CustomPEKModel:
# 初始化公式识别
if
self
.
apply_formula
:
# 初始化公式检测模型
# self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self
.
mfd_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFD
,
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfd"
]))
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfd_model_name
]))
)
# 初始化公式解析模型
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfr"
]))
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfr_model_name
]))
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
"UniMERNet"
,
"demo.yaml"
))
# self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
# self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
self
.
mfr_model
,
self
.
mfr_transform
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
...
...
@@ -243,17 +268,20 @@ class CustomPEKModel:
)
# 初始化layout模型
# self.layout_model = Layoutlmv3_Predictor(
# str(os.path.join(models_dir, self.configs['weights']['layout'])),
# str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
# device=self.device
# )
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
'layout'
])),
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
device
=
self
.
device
)
if
self
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
LAYOUTLMv3
,
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
])),
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
device
=
self
.
device
)
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
DocLayout_YOLO
,
doclayout_yolo_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
]))
)
# 初始化ocr
if
self
.
apply_ocr
:
...
...
@@ -266,12 +294,10 @@ class CustomPEKModel:
)
# init table model
if
self
.
apply_table
:
table_model_dir
=
self
.
configs
[
"weights"
][
self
.
table_model_type
]
# self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
# max_time=self.table_max_time, _device_=self.device)
table_model_dir
=
self
.
configs
[
"weights"
][
self
.
table_model_name
]
self
.
table_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Table
,
table_model_
typ
e
=
self
.
table_model_
typ
e
,
table_model_
nam
e
=
self
.
table_model_
nam
e
,
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_max_time
=
self
.
table_max_time
,
device
=
self
.
device
...
...
@@ -288,7 +314,21 @@ class CustomPEKModel:
# layout检测
layout_start
=
time
.
time
()
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
if
self
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
layout_res
=
[]
doclayout_yolo_res
=
self
.
layout_model
.
predict
(
image
,
imgsz
=
1024
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
,
device
=
self
.
device
)[
0
]
for
xyxy
,
conf
,
cla
in
zip
(
doclayout_yolo_res
.
boxes
.
xyxy
.
cpu
(),
doclayout_yolo_res
.
boxes
.
conf
.
cpu
(),
doclayout_yolo_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
new_item
=
{
'category_id'
:
int
(
cla
.
item
()),
'poly'
:
[
xmin
,
ymin
,
xmax
,
ymin
,
xmax
,
ymax
,
xmin
,
ymax
],
'score'
:
round
(
float
(
conf
.
item
()),
3
),
}
layout_res
.
append
(
new_item
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
logger
.
info
(
f
"layout detection time:
{
layout_cost
}
"
)
...
...
@@ -297,7 +337,7 @@ class CustomPEKModel:
if
self
.
apply_formula
:
# 公式检测
mfd_start
=
time
.
time
()
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
)[
0
]
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
,
device
=
self
.
device
)[
0
]
logger
.
info
(
f
"mfd time:
{
round
(
time
.
time
()
-
mfd_start
,
2
)
}
"
)
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
...
...
@@ -309,7 +349,6 @@ class CustomPEKModel:
}
layout_res
.
append
(
new_item
)
latex_filling_list
.
append
(
new_item
)
# bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
mf_image_list
.
append
(
bbox_img
)
...
...
@@ -346,7 +385,7 @@ class CustomPEKModel:
if
torch
.
cuda
.
is_available
():
properties
=
torch
.
cuda
.
get_device_properties
(
self
.
device
)
total_memory
=
properties
.
total_memory
/
(
1024
**
3
)
# 将字节转换为 GB
if
total_memory
<=
8
:
if
total_memory
<=
10
:
gc_start
=
time
.
time
()
clean_memory
()
gc_time
=
round
(
time
.
time
()
-
gc_start
,
2
)
...
...
@@ -411,7 +450,7 @@ class CustomPEKModel:
# logger.info("------------------table recognition processing begins-----------------")
latex_code
=
None
html_code
=
None
if
self
.
table_model_
typ
e
==
STRUCT_EQTABLE
:
if
self
.
table_model_
nam
e
==
MODEL_NAME
.
STRUCT_EQTABLE
:
with
torch
.
no_grad
():
latex_code
=
self
.
table_model
.
image2latex
(
new_image
)[
0
]
else
:
...
...
magic_pdf/model/ppTableModel.py
View file @
6d571e2e
...
...
@@ -52,11 +52,11 @@ class ppTableModel(object):
rec_model_dir
=
os
.
path
.
join
(
model_dir
,
REC_MODEL_DIR
)
rec_char_dict_path
=
os
.
path
.
join
(
model_dir
,
REC_CHAR_DICT
)
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
use_gpu
=
True
if
device
==
"cuda"
else
False
use_gpu
=
True
if
device
.
startswith
(
"cuda"
)
else
False
config
=
{
"use_gpu"
:
use_gpu
,
"table_max_len"
:
kwargs
.
get
(
"table_max_len"
,
TABLE_MAX_LEN
),
"table_algorithm"
:
TABLE_MASTER
,
"table_algorithm"
:
"TableMaster"
,
"table_model_dir"
:
table_model_dir
,
"table_char_dict_path"
:
table_char_dict_path
,
"det_model_dir"
:
det_model_dir
,
...
...
magic_pdf/para/para_split_v3.py
0 → 100644
View file @
6d571e2e
import
copy
from
loguru
import
logger
from
magic_pdf.libs.Constants
import
LINES_DELETED
,
CROSS_PAGE
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
ContentType
LINE_STOP_FLAG
=
(
'.'
,
'!'
,
'?'
,
'。'
,
'!'
,
'?'
,
')'
,
')'
,
'"'
,
'”'
,
':'
,
':'
,
';'
,
';'
)
LIST_END_FLAG
=
(
'.'
,
'。'
,
';'
,
';'
)
class ListLineTag:
    """Names of the boolean keys written onto line dicts of list/index blocks."""

    # Key set to True on a line that begins a list item.
    IS_LIST_START_LINE = "is_list_start_line"
    # Key set to True on a line that finishes a list item.
    IS_LIST_END_LINE = "is_list_end_line"
def __process_blocks(blocks):
    """Group text blocks and reset their bounding boxes from their lines.

    Two things happen in one pass over ``blocks``:

    1. Every block whose ``type`` is ``'text'`` gets a ``bbox_fs`` entry: a
       copy of its ``bbox`` or, when the block has at least one line, the
       tight box enclosing all line boxes.
    2. Text blocks are collected into groups; a group is closed whenever the
       *next* block is a ``title`` or an ``interline_equation``.

    Returns:
        A list of groups (lists of text blocks, mutated in place).  A group
        may be empty when a separator follows blocks that are not text.
    """
    groups = []
    pending = []
    last_index = len(blocks) - 1

    for index, block in enumerate(blocks):
        if block['type'] == 'text':
            block["bbox_fs"] = copy.deepcopy(block["bbox"])
            line_boxes = [line['bbox'] for line in block.get('lines', ())]
            if line_boxes:
                # Shrink/grow the box so it exactly wraps the lines.
                block['bbox_fs'] = [
                    min(box[0] for box in line_boxes),
                    min(box[1] for box in line_boxes),
                    max(box[2] for box in line_boxes),
                    max(box[3] for box in line_boxes),
                ]
            pending.append(block)

        # A title or interline equation right after this block ends the group.
        if index < last_index and blocks[index + 1]['type'] in ['title', 'interline_equation']:
            groups.append(pending)
            pending = []

    # Flush whatever is left after the final block.
    if pending:
        groups.append(pending)

    return groups
def __is_list_or_index_block(block):
    """Classify a text block as a list, an index (table of contents) or text.

    Heuristics (all distances are compared against the height of the first
    line of the block):

    * A *list* block has several lines, at least two of them flush with the
      left edge, and additionally either at least two lines that stop well
      short of the right edge, or at least two indented lines, or at least
      80% of its lines ending with a ``LIST_END_FLAG`` character.
    * An *index* block is a special list: at least 80% of the lines are flush
      with the left or with the right edge, and at least 80% of the lines
      start with a digit or at least 80% end with a digit.
    * A block whose first line is indented while its last line is flush-left
      and ends short is treated as ordinary paragraphs, never as a list.

    Side effect: lines of list/index blocks are tagged in place with
    ``ListLineTag.IS_LIST_START_LINE`` / ``ListLineTag.IS_LIST_END_LINE``.

    Args:
        block: block dict with ``lines`` (each having ``bbox`` and ``spans``)
            and ``bbox_fs``.

    Returns:
        ``BlockType.Index``, ``BlockType.List`` or ``BlockType.Text``.
    """
    lines = block['lines']
    line_count = len(lines)
    if line_count < 2:
        return BlockType.Text

    first_line = lines[0]
    last_line = lines[-1]
    line_height = first_line['bbox'][3] - first_line['bbox'][1]
    block_left = block['bbox_fs'][0]
    block_right = block['bbox_fs'][2]
    block_weight = block_right - block_left

    # First line indented, last line flush-left and ending short of the right
    # edge: this is the shape of regular paragraphs, not of a list.
    multiple_para_flag = (
        first_line['bbox'][0] - block_left > line_height / 2
        and abs(last_line['bbox'][0] - block_left) < line_height / 2
        and block_right - last_line['bbox'][2] > line_height
    )

    left_close_num = 0
    left_not_close_num = 0
    right_close_num = 0
    right_not_close_num = 0
    lines_text_list = []

    for line in lines:
        # Only text spans contribute, so a line made solely of other span
        # types yields an empty string here.
        line_text = "".join(
            span['content'].strip()
            for span in line['spans']
            if span['type'] == ContentType.Text
        )
        lines_text_list.append(line_text)

        # Left edge: flush within half a line height, or clearly indented.
        if abs(block_left - line['bbox'][0]) < line_height / 2:
            left_close_num += 1
        elif line['bbox'][0] - block_left > line_height:
            left_not_close_num += 1

        # Right edge: flush within one line height; otherwise count the line
        # as "short" only if the gap exceeds 30% of the block width (an
        # empirical threshold).
        if abs(block_right - line['bbox'][2]) < line_height:
            right_close_num += 1
        elif block_right - line['bbox'][2] > 0.3 * block_weight:
            right_not_close_num += 1

    # Text statistics over the non-empty lines.
    flag_end_count = 0
    num_start_count = 0
    num_end_count = 0
    for line_text in lines_text_list:
        if not line_text:
            continue
        if line_text[-1] in LIST_END_FLAG:
            flag_end_count += 1
        if line_text[0].isdigit():
            num_start_count += 1
        if line_text[-1].isdigit():
            num_end_count += 1

    # At least 80% of the lines end with a list terminator.
    line_end_flag = flag_end_count / line_count >= 0.8
    # At least 80% of the lines start with a digit, or at least 80% end with one.
    line_num_flag = (
        num_start_count / line_count >= 0.8
        or num_end_count / line_count >= 0.8
    )

    # Some tables of contents are not flush on the right, so one flush side
    # plus the digit rule is enough to call the block an index.
    if (
        (left_close_num / line_count >= 0.8 or right_close_num / line_count >= 0.8)
        and line_num_flag
    ):
        for line in lines:
            line[ListLineTag.IS_LIST_START_LINE] = True
        return BlockType.Index

    looks_like_list = (
        left_close_num >= 2
        and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
        and not multiple_para_flag
    )
    if not looks_like_list:
        return BlockType.Text

    if left_close_num / line_count > 0.9:
        # List without indentation: (almost) every line is flush-left.
        if flag_end_count == 0 and right_close_num / line_count < 0.5:
            # Short one-line items: every flush-left line starts an item.
            for line in lines:
                if abs(block_left - line['bbox'][0]) < line_height / 2:
                    line[ListLineTag.IS_LIST_START_LINE] = True
        elif line_end_flag:
            # Most items carry a terminator: split the items on it.
            for i, line in enumerate(lines):
                line_text = lines_text_list[i]
                # Lines without text cannot end an item; indexing them would
                # raise IndexError.
                if line_text and line_text[-1] in LIST_END_FLAG:
                    line[ListLineTag.IS_LIST_END_LINE] = True
                    if i + 1 < line_count:
                        lines[i + 1][ListLineTag.IS_LIST_START_LINE] = True
        else:
            # No terminators and no indentation: a line that stops short of
            # the right edge ends an item and the following line starts one.
            line_start_flag = False
            for line in lines:
                if line_start_flag:
                    line[ListLineTag.IS_LIST_START_LINE] = True
                    line_start_flag = False
                elif abs(block_right - line['bbox'][2]) > line_height:
                    line[ListLineTag.IS_LIST_END_LINE] = True
                    line_start_flag = True
    elif num_start_count >= 2 and num_start_count == flag_end_count:
        # Indented ordered list: items start with a digit and end with a
        # terminator, and both counts agree.  Left alignment is deliberately
        # ignored here to keep the rule simple.
        for line, line_text in zip(lines, lines_text_list):
            if not line_text:
                # Nothing to inspect on a line without text.
                continue
            if line_text[0].isdigit():
                line[ListLineTag.IS_LIST_START_LINE] = True
            if line_text[-1] in LIST_END_FLAG:
                line[ListLineTag.IS_LIST_END_LINE] = True
    else:
        # Ordinary indented list: flush-left lines start an item, lines that
        # stop short of the right edge end one.
        for line in lines:
            if abs(block_left - line['bbox'][0]) < line_height / 2:
                line[ListLineTag.IS_LIST_START_LINE] = True
            if abs(block_right - line['bbox'][2]) > line_height:
                line[ListLineTag.IS_LIST_END_LINE] = True

    return BlockType.List
def __merge_2_text_blocks(block1, block2):
    """Merge text block ``block1`` into ``block2`` when they form one paragraph.

    The lines of ``block1`` are appended to ``block2`` when all of these hold:

    * the first line of ``block1`` is flush with the block's left edge
      (no paragraph indent);
    * the last line of ``block2`` reaches the block's right edge;
    * the last span of ``block2`` does not end with a ``LINE_STOP_FLAG``
      character;
    * the widths of the two blocks differ by less than the narrower width.

    On a merge, ``block1['lines']`` is emptied and ``block1`` is marked with
    ``LINES_DELETED``; when the blocks sit on different pages every moved
    span is marked with ``CROSS_PAGE``.

    Args:
        block1: the later block (source of the lines).
        block2: the earlier block (receives the lines).

    Returns:
        The tuple ``(block1, block2)``; both dicts are modified in place.
    """
    # Nothing to move, or no last line in block2 to compare against
    # (indexing an empty ``lines`` list would raise IndexError).
    if len(block1['lines']) == 0 or len(block2['lines']) == 0:
        return block1, block2

    first_line = block1['lines'][0]
    line_height = first_line['bbox'][3] - first_line['bbox'][1]
    block1_weight = block1['bbox'][2] - block1['bbox'][0]
    block2_weight = block2['bbox'][2] - block2['bbox'][0]
    min_block_weight = min(block1_weight, block2_weight)

    # An indented first line starts a new paragraph: do not merge.
    if not abs(block1['bbox_fs'][0] - first_line['bbox'][0]) < line_height / 2:
        return block1, block2

    last_line = block2['lines'][-1]
    if len(last_line['spans']) == 0:
        return block1, block2

    last_span = last_line['spans'][-1]
    # From here on distances are measured against block2's last line.
    line_height = last_line['bbox'][3] - last_line['bbox'][1]

    should_merge = (
        abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height
        and not last_span['content'].endswith(LINE_STOP_FLAG)
        # Blocks whose widths differ by more than a factor of two stay apart.
        and abs(block1_weight - block2_weight) < min_block_weight
    )
    if should_merge:
        if block1['page_num'] != block2['page_num']:
            for line in block1['lines']:
                for span in line['spans']:
                    span[CROSS_PAGE] = True
        block2['lines'].extend(block1['lines'])
        block1['lines'] = []
        block1[LINES_DELETED] = True

    return block1, block2
def __merge_2_list_blocks(block1, block2):
    """Move every line of ``block1`` to the end of ``block2``.

    When the two blocks carry different ``page_num`` values, each span of the
    moved lines is marked with ``CROSS_PAGE`` first.  Afterwards ``block1``
    has an empty ``lines`` list and is marked with ``LINES_DELETED``.

    Returns:
        The tuple ``(block1, block2)``; both dicts are modified in place.
    """
    crosses_page = block1['page_num'] != block2['page_num']
    if crosses_page:
        for moved_line in block1['lines']:
            for moved_span in moved_line['spans']:
                moved_span[CROSS_PAGE] = True

    target_lines = block2['lines']
    target_lines.extend(block1['lines'])

    # block1 is now an empty shell; flag it as such.
    block1['lines'] = []
    block1[LINES_DELETED] = True
    return block1, block2
def __is_list_group(text_blocks_group):
    """Return True when no block in the group has more than three lines.

    A group whose blocks are all this short is treated as a list group.  A
    check on how close the blocks' left edges are is intentionally left out
    to keep the rule simple.  An empty group yields True.
    """
    return all(len(block['lines']) <= 3 for block in text_blocks_group)
def __para_merge_page(blocks):
    """Classify the text blocks in ``blocks`` and merge neighbouring ones.

    The blocks are first split into groups by ``__process_blocks``.  Inside
    each group every block's ``type`` is replaced by the result of
    ``__is_list_or_index_block``.  Groups with more than one block are then
    walked from the last block to the first, and each block is offered for
    merging into its predecessor:

    * two text blocks are merged with ``__merge_2_text_blocks`` unless the
      whole group is a list group;
    * two list blocks, or two index blocks, are merged with
      ``__merge_2_list_blocks``.

    All changes are made in place; nothing is returned.
    """
    for group in __process_blocks(blocks):
        # Classification must be complete before any merging starts.
        for block in group:
            block['type'] = __is_list_or_index_block(block)

        if len(group) <= 1:
            continue

        # Decide once, before merging alters the line counts.
        is_list_group = __is_list_group(group)

        # Back to front, so lines cascade towards the earliest block.
        for position in range(len(group) - 1, 0, -1):
            current_block = group[position]
            prev_block = group[position - 1]
            current_type = current_block['type']
            prev_type = prev_block['type']

            if current_type == 'text' and prev_type == 'text' and not is_list_group:
                __merge_2_text_blocks(current_block, prev_block)
            elif any(
                current_type == kind and prev_type == kind
                for kind in (BlockType.List, BlockType.Index)
            ):
                __merge_2_list_blocks(current_block, prev_block)
def para_split(pdf_info_dict, debug_mode=False):
    """Build ``para_blocks`` for every page of ``pdf_info_dict``.

    The ``preproc_blocks`` of all pages are deep-copied, tagged with their
    ``page_num`` and merged as one sequence by ``__para_merge_page``, so that
    paragraphs may continue across pages.  The processed blocks are then
    handed back to their pages, in their original order, under the key
    ``para_blocks``.

    Args:
        pdf_info_dict: mapping of page number to page dict; each page dict
            must contain ``preproc_blocks``.
        debug_mode: accepted for interface compatibility; not used here.

    Returns:
        None.  ``pdf_info_dict`` is updated in place.
    """
    all_blocks = []
    for page_num, page in pdf_info_dict.items():
        page_blocks = copy.deepcopy(page['preproc_blocks'])
        for block in page_blocks:
            block['page_num'] = page_num
        all_blocks.extend(page_blocks)

    __para_merge_page(all_blocks)

    # Group the blocks by page in a single pass instead of rescanning the
    # whole block list once per page.
    blocks_by_page = {}
    for block in all_blocks:
        blocks_by_page.setdefault(block['page_num'], []).append(block)

    for page_num, page in pdf_info_dict.items():
        page['para_blocks'] = blocks_by_page.get(page_num, [])
if __name__ == '__main__':
    # Manual smoke test: group an (empty) block list and print each group.
    input_blocks = []

    groups = __process_blocks(input_blocks)
    for group_index, group in enumerate(groups):
        print(f"Group {group_index}: {group}")
magic_pdf/pdf_parse_by_ocr.py
View file @
6d571e2e
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
...
...
@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
end_page_id
=
None
,
debug_mode
=
False
,
):
return
pdf_parse_union
(
pdf_bytes
,
dataset
=
PymuDocDataset
(
pdf_bytes
)
return
pdf_parse_union
(
dataset
,
model_list
,
imageWriter
,
"ocr"
,
SupportedPdfParseMethod
.
OCR
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
debug_mode
=
debug_mode
,
...
...
magic_pdf/pdf_parse_by_txt.py
View file @
6d571e2e
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
...
...
@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
end_page_id
=
None
,
debug_mode
=
False
,
):
return
pdf_parse_union
(
pdf_bytes
,
dataset
=
PymuDocDataset
(
pdf_bytes
)
return
pdf_parse_union
(
dataset
,
model_list
,
imageWriter
,
"txt"
,
SupportedPdfParseMethod
.
TXT
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
debug_mode
=
debug_mode
,
...
...
magic_pdf/pdf_parse_union_core_v2.py
View file @
6d571e2e
This diff is collapsed.
Click to expand it.
magic_pdf/pipe/AbsPipe.py
View file @
6d571e2e
...
...
@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT
=
"txt"
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_bytes
=
pdf_bytes
self
.
model_list
=
model_list
self
.
image_writer
=
image_writer
...
...
@@ -26,6 +26,9 @@ class AbsPipe(ABC):
self
.
start_page_id
=
start_page_id
self
.
end_page_id
=
end_page_id
self
.
lang
=
lang
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
get_compress_pdf_mid_data
(
self
):
return
JsonCompressor
.
compress_json
(
self
.
pdf_mid_data
)
...
...
@@ -95,9 +98,7 @@ class AbsPipe(ABC):
"""
pdf_mid_data
=
JsonCompressor
.
decompress_json
(
compressed_pdf_mid_data
)
pdf_info_list
=
pdf_mid_data
[
"pdf_info"
]
parse_type
=
pdf_mid_data
[
"_parse_type"
]
lang
=
pdf_mid_data
.
get
(
"_lang"
,
None
)
content_list
=
union_make
(
pdf_info_list
,
MakeMode
.
STANDARD_FORMAT
,
drop_mode
,
img_buket_path
,
parse_type
,
lang
)
content_list
=
union_make
(
pdf_info_list
,
MakeMode
.
STANDARD_FORMAT
,
drop_mode
,
img_buket_path
)
return
content_list
@
staticmethod
...
...
@@ -107,9 +108,7 @@ class AbsPipe(ABC):
"""
pdf_mid_data
=
JsonCompressor
.
decompress_json
(
compressed_pdf_mid_data
)
pdf_info_list
=
pdf_mid_data
[
"pdf_info"
]
parse_type
=
pdf_mid_data
[
"_parse_type"
]
lang
=
pdf_mid_data
.
get
(
"_lang"
,
None
)
md_content
=
union_make
(
pdf_info_list
,
md_make_mode
,
drop_mode
,
img_buket_path
,
parse_type
,
lang
)
md_content
=
union_make
(
pdf_info_list
,
md_make_mode
,
drop_mode
,
img_buket_path
)
return
md_content
magic_pdf/pipe/OCRPipe.py
View file @
6d571e2e
...
...
@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
class
OCRPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
pass
...
...
@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/TXTPipe.py
View file @
6d571e2e
...
...
@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
class
TXTPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
pass
...
...
@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_txt_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/UNIPipe.py
View file @
6d571e2e
...
...
@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class
UNIPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
jso_useful_key
:
dict
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_type
=
jso_useful_key
[
"_pdf_type"
]
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
if
len
(
self
.
model_list
)
==
0
:
self
.
input_model_is_empty
=
True
else
:
...
...
@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
pdf_mid_data
=
parse_union_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
input_model_is_empty
=
self
.
input_model_is_empty
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
...
...
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
View file @
6d571e2e
from
loguru
import
logger
from
magic_pdf.libs.boxbase
import
get_minbox_if_overlap_by_ratio
,
calculate_overlap_area_in_bbox1_area_ratio
,
\
calculate_iou
calculate_iou
,
calculate_vertical_projection_overlap_ratio
from
magic_pdf.libs.drop_tag
import
DropTag
from
magic_pdf.libs.ocr_content_type
import
BlockType
from
magic_pdf.pre_proc.remove_bbox_overlap
import
remove_overlap_between_bbox_for_block
...
...
@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
return
all_bboxes
,
all_discarded_blocks
,
drop_reasons
def
ocr_prepare_bboxes_for_layout_split_v2
(
img_blocks
,
table_blocks
,
discarded_blocks
,
text_blocks
,
title_blocks
,
interline_equation_blocks
,
page_w
,
page_h
):
all_bboxes
=
[]
all_discarded_blocks
=
[]
for
image
in
img_blocks
:
x0
,
y0
,
x1
,
y1
=
image
[
'bbox'
]
all_bboxes
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Image
,
None
,
None
,
None
,
None
,
image
[
"score"
]])
def add_bboxes(blocks, block_type, bboxes):
    """Append one bbox record per block to ``bboxes``.

    Each record is a list laid out as
    ``[x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, score]``.
    For image/table body, caption and footnote types the block's ``group_id``
    is appended as an additional, final element.

    Args:
        blocks: iterable of block dicts with ``bbox`` and ``score`` (and
            ``group_id`` for image/table sub-blocks).
        block_type: type tag stored in every record.
        bboxes: output list, extended in place.
    """
    for block in blocks:
        x0, y0, x1, y1 = block['bbox']
        record = [x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]]
        if block_type in [
            BlockType.ImageBody,
            BlockType.ImageCaption,
            BlockType.ImageFootnote,
            BlockType.TableBody,
            BlockType.TableCaption,
            BlockType.TableFootnote,
        ]:
            # Image/table parts also record which group they belong to.
            record.append(block["group_id"])
        bboxes.append(record)
for
table
in
table_blocks
:
x0
,
y0
,
x1
,
y1
=
table
[
'bbox'
]
all_bboxes
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Table
,
None
,
None
,
None
,
None
,
table
[
"score"
]])
for
text
in
text_blocks
:
x0
,
y0
,
x1
,
y1
=
text
[
'bbox'
]
all_bboxes
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Text
,
None
,
None
,
None
,
None
,
text
[
"score"
]])
def
ocr_prepare_bboxes_for_layout_split_v2
(
img_body_blocks
,
img_caption_blocks
,
img_footnote_blocks
,
table_body_blocks
,
table_caption_blocks
,
table_footnote_blocks
,
discarded_blocks
,
text_blocks
,
title_blocks
,
interline_equation_blocks
,
page_w
,
page_h
):
all_bboxes
=
[]
for
title
in
title_blocks
:
x0
,
y0
,
x1
,
y1
=
title
[
'bbox'
]
all_bboxes
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Title
,
None
,
None
,
None
,
None
,
title
[
"score"
]])
for
interline_equation
in
interline_equation_blocks
:
x0
,
y0
,
x1
,
y1
=
interline_equation
[
'bbox'
]
all_bboxes
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
InterlineEquation
,
None
,
None
,
None
,
None
,
interline_equation
[
"score"
]])
add_bboxes
(
img_body_blocks
,
BlockType
.
ImageBody
,
all_bboxes
)
add_bboxes
(
img_caption_blocks
,
BlockType
.
ImageCaption
,
all_bboxes
)
add_bboxes
(
img_footnote_blocks
,
BlockType
.
ImageFootnote
,
all_bboxes
)
add_bboxes
(
table_body_blocks
,
BlockType
.
TableBody
,
all_bboxes
)
add_bboxes
(
table_caption_blocks
,
BlockType
.
TableCaption
,
all_bboxes
)
add_bboxes
(
table_footnote_blocks
,
BlockType
.
TableFootnote
,
all_bboxes
)
add_bboxes
(
text_blocks
,
BlockType
.
Text
,
all_bboxes
)
add_bboxes
(
title_blocks
,
BlockType
.
Title
,
all_bboxes
)
add_bboxes
(
interline_equation_blocks
,
BlockType
.
InterlineEquation
,
all_bboxes
)
'''block嵌套问题解决'''
'''文本框与标题框重叠,优先信任文本框'''
...
...
@@ -96,23 +101,47 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
'''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
# 通过后续大框套小框逻辑删除
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
'''discarded_blocks'''
all_discarded_blocks
=
[]
add_bboxes
(
discarded_blocks
,
BlockType
.
Discarded
,
all_discarded_blocks
)
'''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
footnote_blocks
=
[]
for
discarded
in
discarded_blocks
:
x0
,
y0
,
x1
,
y1
=
discarded
[
'bbox'
]
all_discarded_blocks
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Discarded
,
None
,
None
,
None
,
None
,
discarded
[
"score"
]])
# 将footnote加入到all_bboxes中,用来计算layout
# if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
# all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
if
(
x1
-
x0
)
>
(
page_w
/
3
)
and
(
y1
-
y0
)
>
10
and
y0
>
(
page_h
/
2
):
footnote_blocks
.
append
([
x0
,
y0
,
x1
,
y1
])
'''移除在footnote下面的任何框'''
need_remove_blocks
=
find_blocks_under_footnote
(
all_bboxes
,
footnote_blocks
)
if
len
(
need_remove_blocks
)
>
0
:
for
block
in
need_remove_blocks
:
all_bboxes
.
remove
(
block
)
all_discarded_blocks
.
append
(
block
)
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes
=
remove_overlaps_min_blocks
(
all_bboxes
)
all_discarded_blocks
=
remove_overlaps_min_blocks
(
all_discarded_blocks
)
'''将剩余的bbox做分离处理,防止后面分layout时出错'''
#
all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
all_bboxes
,
drop_reasons
=
remove_overlap_between_bbox_for_block
(
all_bboxes
)
return
all_bboxes
,
all_discarded_blocks
def find_blocks_under_footnote(all_bboxes, footnote_blocks):
    """Collect the blocks that lie below a footnote.

    A block is selected when, for at least one footnote box, its top edge
    (``y0``) is at or below the footnote's bottom edge (``y1``) and
    ``calculate_vertical_projection_overlap_ratio`` of the block box and the
    footnote box is at least 0.8.

    Args:
        all_bboxes: block records whose first four items are ``x0, y0, x1, y1``.
        footnote_blocks: footnote boxes, each exactly ``[x0, y0, x1, y1]``.

    Returns:
        The selected block records, in input order, without duplicates.
    """
    blocks_below = []
    for candidate in all_bboxes:
        cand_x0, cand_y0, cand_x1, cand_y1 = candidate[:4]
        for footnote in footnote_blocks:
            _fn_x0, _fn_y0, _fn_x1, footnote_bottom = footnote
            # Only blocks starting at or below the footnote's bottom qualify.
            if not cand_y0 >= footnote_bottom:
                continue
            overlap = calculate_vertical_projection_overlap_ratio(
                (cand_x0, cand_y0, cand_x1, cand_y1), footnote)
            if overlap >= 0.8:
                if candidate not in blocks_below:
                    blocks_below.append(candidate)
                # One matching footnote is enough for this block.
                break
    return blocks_below
def
fix_interline_equation_overlap_text_blocks_with_hi_iou
(
all_bboxes
):
# 先提取所有text和interline block
text_blocks
=
[]
...
...
magic_pdf/pre_proc/ocr_dict_merge.py
View file @
6d571e2e
...
...
@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if
__is_overlaps_y_exceeds_threshold
(
span
[
'bbox'
],
current_line
[
-
1
][
'bbox'
]):
if
__is_overlaps_y_exceeds_threshold
(
span
[
'bbox'
],
current_line
[
-
1
][
'bbox'
],
0.5
):
current_line
.
append
(
span
)
else
:
# 否则,开始新行
...
...
@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
'type'
:
block_type
,
'bbox'
:
block_bbox
,
}
if
block_type
in
[
BlockType
.
ImageBody
,
BlockType
.
ImageCaption
,
BlockType
.
ImageFootnote
,
BlockType
.
TableBody
,
BlockType
.
TableCaption
,
BlockType
.
TableFootnote
]:
block_dict
[
"group_id"
]
=
block
[
-
1
]
block_spans
=
[]
for
span
in
spans
:
span_bbox
=
span
[
'bbox'
]
...
...
@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
return
fix_blocks
def fix_block_spans_v2(block_with_spans):
    """Run each block through the fixer that matches its type.

    * Text, title and image/table caption and footnote blocks go through
      ``fix_text_block``.
    * Interline-equation, image-body and table-body blocks go through
      ``fix_interline_block``.
    * Blocks of any other type are dropped from the result.

    NOTE(review): the fixers are presumably what turns a block's ``spans``
    into its final structure and removes the ``spans`` field; confirm
    against their definitions.

    Returns:
        A new list with the fixed blocks, in input order.
    """
    fixed_blocks = []
    for block in block_with_spans:
        kind = block['type']
        if kind in [
            BlockType.Text,
            BlockType.Title,
            BlockType.ImageCaption,
            BlockType.ImageFootnote,
            BlockType.TableCaption,
            BlockType.TableFootnote,
        ]:
            fixer = fix_text_block
        elif kind in [
            BlockType.InterlineEquation,
            BlockType.ImageBody,
            BlockType.TableBody,
        ]:
            fixer = fix_interline_block
        else:
            # Unsupported block types are skipped.
            continue
        fixed_blocks.append(fixer(block))
    return fixed_blocks
def
fix_discarded_block
(
discarded_block_with_spans
):
fix_discarded_blocks
=
[]
for
block
in
discarded_block_with_spans
:
...
...
magic_pdf/resources/model_config/model_configs.yaml
View file @
6d571e2e
config
:
device
:
cpu
layout
:
True
formula
:
True
table_config
:
model
:
TableMaster
is_table_recog_enable
:
False
max_time
:
400
weights
:
layout
:
Layout/model_final.pth
mfd
:
MFD/weights.pt
mfr
:
MFR/unimernet_small
layoutlmv3
:
Layout/LayoutLMv3/model_final.pth
doclayout_yolo
:
Layout/YOLO/doclayout_yolo_ft.pt
yolo_v8_mfd
:
MFD/YOLO/yolo_v8_ft.pt
unimernet_small
:
MFR/unimernet_small
struct_eqtable
:
TabRec/StructEqTable
TableMaster
:
TabRec/TableMaster
\ No newline at end of file
tablemaster
:
TabRec/TableMaster
\ No newline at end of file
magic_pdf/tools/cli.py
View file @
6d571e2e
...
...
@@ -52,7 +52,7 @@ without method specified, auto will be used by default.""",
help
=
"""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
You should input "Abbreviation" with language form url:
https://paddlepaddle.github.io/PaddleOCR/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
https://paddlepaddle.github.io/PaddleOCR/
latest/
en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
"""
,
default
=
None
,
)
...
...
magic_pdf/tools/common.py
View file @
6d571e2e
...
...
@@ -6,8 +6,8 @@ import click
from
loguru
import
logger
import
magic_pdf.model
as
model_config
from
magic_pdf.libs.draw_bbox
import
(
draw_layout_bbox
,
draw_
span
_bbox
,
draw_model_bbox
,
draw_
line_sort
_bbox
)
from
magic_pdf.libs.draw_bbox
import
(
draw_layout_bbox
,
draw_
line_sort
_bbox
,
draw_model_bbox
,
draw_
span
_bbox
)
from
magic_pdf.libs.MakeContentConfig
import
DropMode
,
MakeMode
from
magic_pdf.pipe.OCRPipe
import
OCRPipe
from
magic_pdf.pipe.TXTPipe
import
TXTPipe
...
...
@@ -46,10 +46,12 @@ def do_parse(
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
):
if
debug_able
:
logger
.
warning
(
'debug mode is on'
)
# f_dump_content_list = True
f_draw_model_bbox
=
True
f_draw_line_sort_bbox
=
True
...
...
@@ -64,13 +66,16 @@ def do_parse(
if
parse_method
==
'auto'
:
jso_useful_key
=
{
'_pdf_type'
:
''
,
'model_list'
:
model_list
}
pipe
=
UNIPipe
(
pdf_bytes
,
jso_useful_key
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'txt'
:
pipe
=
TXTPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'ocr'
:
pipe
=
OCRPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
else
:
logger
.
error
(
'unknown parse method'
)
exit
(
1
)
...
...
magic_pdf/user_api.py
View file @
6d571e2e
...
...
@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if
pdf_info_dict
is
None
or
pdf_info_dict
.
get
(
"_need_drop"
,
False
):
logger
.
warning
(
f
"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr"
)
if
input_model_is_empty
:
pdf_models
=
doc_analyze
(
pdf_bytes
,
ocr
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
layout_model
=
kwargs
.
get
(
"layout_model"
,
None
)
formula_enable
=
kwargs
.
get
(
"formula_enable"
,
None
)
table_enable
=
kwargs
.
get
(
"table_enable"
,
None
)
pdf_models
=
doc_analyze
(
pdf_bytes
,
ocr
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
,
)
pdf_info_dict
=
parse_pdf
(
parse_pdf_by_ocr
)
if
pdf_info_dict
is
None
:
raise
Exception
(
"Both parse_pdf_by_txt and parse_pdf_by_ocr failed."
)
...
...
magic_pdf/utils/__init__.py
0 → 100644
View file @
6d571e2e
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment