Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
ece7f8d5
Unverified
Commit
ece7f8d5
authored
Oct 15, 2024
by
Kaiwen Liu
Committed by
GitHub
Oct 15, 2024
Browse files
Merge pull request #6 from opendatalab/dev
Dev
parents
98362a6e
702b6ac9
Changes
551
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1165 additions
and
142 deletions
+1165
-142
magic_pdf/libs/ocr_content_type.py
magic_pdf/libs/ocr_content_type.py
+2
-0
magic_pdf/libs/version.py
magic_pdf/libs/version.py
+1
-1
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+48
-27
magic_pdf/model/magic_model.py
magic_pdf/model/magic_model.py
+115
-46
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+65
-36
magic_pdf/model/pp_structure_v2.py
magic_pdf/model/pp_structure_v2.py
+5
-2
magic_pdf/model/v3/__init__.py
magic_pdf/model/v3/__init__.py
+0
-0
magic_pdf/model/v3/helpers.py
magic_pdf/model/v3/helpers.py
+125
-0
magic_pdf/para/para_split_v3.py
magic_pdf/para/para_split_v3.py
+256
-0
magic_pdf/pdf_parse_by_ocr.py
magic_pdf/pdf_parse_by_ocr.py
+1
-1
magic_pdf/pdf_parse_by_txt.py
magic_pdf/pdf_parse_by_txt.py
+1
-1
magic_pdf/pdf_parse_union_core_v2.py
magic_pdf/pdf_parse_union_core_v2.py
+453
-0
magic_pdf/pipe/AbsPipe.py
magic_pdf/pipe/AbsPipe.py
+8
-3
magic_pdf/pipe/OCRPipe.py
magic_pdf/pipe/OCRPipe.py
+6
-4
magic_pdf/pipe/TXTPipe.py
magic_pdf/pipe/TXTPipe.py
+6
-4
magic_pdf/pipe/UNIPipe.py
magic_pdf/pipe/UNIPipe.py
+11
-7
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
+53
-0
magic_pdf/pre_proc/ocr_dict_merge.py
magic_pdf/pre_proc/ocr_dict_merge.py
+1
-2
magic_pdf/resources/model_config/UniMERNet/demo.yaml
magic_pdf/resources/model_config/UniMERNet/demo.yaml
+7
-7
magic_pdf/resources/model_config/model_configs.yaml
magic_pdf/resources/model_config/model_configs.yaml
+1
-1
No files found.
magic_pdf/libs/ocr_content_type.py
View file @
ece7f8d5
...
...
@@ -20,6 +20,8 @@ class BlockType:
InterlineEquation
=
'interline_equation'
Footnote
=
'footnote'
Discarded
=
'discarded'
List
=
'list'
Index
=
'index'
class
CategoryId
:
...
...
magic_pdf/libs/version.py
View file @
ece7f8d5
__version__
=
"0.
7.1
"
__version__
=
"0.
8.0
"
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
ece7f8d5
...
...
@@ -4,6 +4,7 @@ import fitz
import
numpy
as
np
from
loguru
import
logger
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
,
get_table_recog_config
from
magic_pdf.model.model_list
import
MODEL
import
magic_pdf.model
as
model_config
...
...
@@ -23,7 +24,7 @@ def remove_duplicates_dicts(lst):
return
unique_dicts
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
)
->
list
:
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
try
:
from
PIL
import
Image
except
ImportError
:
...
...
@@ -32,18 +33,28 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200) -> list:
images
=
[]
with
fitz
.
open
(
"pdf"
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
end_page_id
=
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
pdf_page_num
-
1
if
end_page_id
>
pdf_page_num
-
1
:
logger
.
warning
(
"end_page_id is out of range, use images length"
)
end_page_id
=
pdf_page_num
-
1
for
index
in
range
(
0
,
doc
.
page_count
):
page
=
doc
[
index
]
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
page
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
if
start_page_id
<=
index
<=
end_page_id
:
page
=
doc
[
index
]
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
page
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
# If the width or height exceeds 9000 after scaling, do not scale further.
if
pm
.
width
>
9000
or
pm
.
height
>
9000
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
# If the width or height exceeds 9000 after scaling, do not scale further.
if
pm
.
width
>
9000
or
pm
.
height
>
9000
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
"RGB"
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
img_dict
=
{
"img"
:
img
,
"width"
:
pm
.
width
,
"height"
:
pm
.
height
}
else
:
img_dict
=
{
"img"
:
[],
"width"
:
0
,
"height"
:
0
}
img
=
Image
.
frombytes
(
"RGB"
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
img_dict
=
{
"img"
:
img
,
"width"
:
pm
.
width
,
"height"
:
pm
.
height
}
images
.
append
(
img_dict
)
return
images
...
...
@@ -57,14 +68,14 @@ class ModelSingleton:
cls
.
_instance
=
super
().
__new__
(
cls
)
return
cls
.
_instance
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
):
key
=
(
ocr
,
show_log
)
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
,
lang
=
None
):
key
=
(
ocr
,
show_log
,
lang
)
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
)
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
)
return
self
.
_models
[
key
]
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
):
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
):
model
=
None
if
model_config
.
__model_mode__
==
"lite"
:
...
...
@@ -78,7 +89,7 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
model_init_start
=
time
.
time
()
if
model
==
MODEL
.
Paddle
:
from
magic_pdf.model.pp_structure_v2
import
CustomPaddleModel
custom_model
=
CustomPaddleModel
(
ocr
=
ocr
,
show_log
=
show_log
)
custom_model
=
CustomPaddleModel
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
)
elif
model
==
MODEL
.
PEK
:
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
# 从配置文件读取model-dir和device
...
...
@@ -89,7 +100,9 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
"show_log"
:
show_log
,
"models_dir"
:
local_models_dir
,
"device"
:
device
,
"table_config"
:
table_config
}
"table_config"
:
table_config
,
"lang"
:
lang
,
}
custom_model
=
CustomPEKModel
(
**
model_input
)
else
:
logger
.
error
(
"Not allow model_name!"
)
...
...
@@ -104,19 +117,19 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
def
doc_analyze
(
pdf_bytes
:
bytes
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
)
images
=
load_images_from_pdf
(
pdf_bytes
)
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
)
# end_page_id = end_page_id if end_page_id else len(images) - 1
end_page_id
=
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
len
(
images
)
-
1
with
fitz
.
open
(
"pdf"
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
end_page_id
=
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
pdf_page_num
-
1
if
end_page_id
>
pdf_page_num
-
1
:
logger
.
warning
(
"end_page_id is out of range, use images length"
)
end_page_id
=
pdf_page_num
-
1
if
end_page_id
>
len
(
images
)
-
1
:
logger
.
warning
(
"end_page_id is out of range, use images length"
)
end_page_id
=
len
(
images
)
-
1
images
=
load_images_from_pdf
(
pdf_bytes
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
)
model_json
=
[]
doc_analyze_start
=
time
.
time
()
...
...
@@ -132,7 +145,15 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_info
=
{
"page_no"
:
index
,
"height"
:
page_height
,
"width"
:
page_width
}
page_dict
=
{
"layout_dets"
:
result
,
"page_info"
:
page_info
}
model_json
.
append
(
page_dict
)
doc_analyze_cost
=
time
.
time
()
-
doc_analyze_start
logger
.
info
(
f
"doc analyze cost:
{
doc_analyze_cost
}
"
)
gc_start
=
time
.
time
()
clean_memory
()
gc_time
=
round
(
time
.
time
()
-
gc_start
,
2
)
logger
.
info
(
f
"gc time:
{
gc_time
}
"
)
doc_analyze_time
=
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
doc_analyze_speed
=
round
(
(
end_page_id
+
1
-
start_page_id
)
/
doc_analyze_time
,
2
)
logger
.
info
(
f
"doc analyze time:
{
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
}
,"
f
" speed:
{
doc_analyze_speed
}
pages/second"
)
return
model_json
magic_pdf/model/magic_model.py
View file @
ece7f8d5
import
json
from
magic_pdf.libs.boxbase
import
(
_is_in
,
_is_part_overlap
,
bbox_distance
,
bbox_relative_pos
,
calculate_iou
,
calculate_overlap_area_in_bbox1_area_ratio
)
bbox_relative_pos
,
box_area
,
calculate_iou
,
calculate_overlap_area_in_bbox1_area_ratio
,
get_overlap_area
)
from
magic_pdf.libs.commons
import
fitz
,
join_path
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.libs.local_math
import
float_gt
...
...
@@ -12,6 +13,7 @@ from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from
magic_pdf.rw.DiskReaderWriter
import
DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO
=
0.6
MERGE_BOX_OVERLAP_AREA_RATIO
=
1.1
class
MagicModel
:
...
...
@@ -108,6 +110,24 @@ class MagicModel:
self
.
__fix_by_remove_high_iou_and_low_confidence
()
self
.
__fix_footnote
()
def
_bbox_distance
(
self
,
bbox1
,
bbox2
):
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
flags
=
[
left
,
right
,
bottom
,
top
]
count
=
sum
([
1
if
v
else
0
for
v
in
flags
])
if
count
>
1
:
return
float
(
'inf'
)
if
left
or
right
:
l1
=
bbox1
[
3
]
-
bbox1
[
1
]
l2
=
bbox2
[
3
]
-
bbox2
[
1
]
else
:
l1
=
bbox1
[
2
]
-
bbox1
[
0
]
l2
=
bbox2
[
2
]
-
bbox2
[
0
]
if
l2
>
l1
and
(
l2
-
l1
)
/
l1
>
0.5
:
return
float
(
'inf'
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
__fix_footnote
(
self
):
# 3: figure, 5: table, 7: footnote
for
model_page_info
in
self
.
__model_list
:
...
...
@@ -124,49 +144,51 @@ class MagicModel:
tables
.
append
(
obj
)
if
len
(
footnotes
)
*
len
(
figures
)
==
0
:
continue
dis_figure_footnote
=
{}
dis_table_footnote
=
{}
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
figures
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
figures
[
j
][
'bbox'
]
),
)
dis_figure_footnote
=
{}
dis_table_footnote
=
{}
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
figures
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
figures
[
j
][
'bbox'
]
),
)
)
if
pos_flag_count
>
1
:
cont
inue
dis_figure_footnote
[
i
]
=
min
(
bbox_distance
(
figures
[
j
][
'bbox'
],
footnote
s
[
i
]
[
'bbox'
]),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)
),
)
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
tabl
es
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
tables
[
j
][
'bbox'
]
),
)
)
if
pos_flag_
co
u
nt
>
1
:
continue
dis_figure_
footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]
),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnot
es
)):
for
j
in
range
(
len
(
tables
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
tables
[
j
][
'bbox'
]
)
,
)
)
if
pos_flag_count
>
1
:
continue
)
if
pos_flag_count
>
1
:
continue
dis_table_footnote
[
i
]
=
min
(
bbox_distance
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
dis_table_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
if
i
not
in
dis_figure_footnote
:
continue
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
def
__reduct_overlap
(
self
,
bboxes
):
N
=
len
(
bboxes
)
...
...
@@ -191,6 +213,44 @@ class MagicModel:
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def
search_overlap_between_boxes
(
subject_idx
,
object_idx
):
idxes
=
[
subject_idx
,
object_idx
]
x0s
=
[
all_bboxes
[
idx
][
'bbox'
][
0
]
for
idx
in
idxes
]
y0s
=
[
all_bboxes
[
idx
][
'bbox'
][
1
]
for
idx
in
idxes
]
x1s
=
[
all_bboxes
[
idx
][
'bbox'
][
2
]
for
idx
in
idxes
]
y1s
=
[
all_bboxes
[
idx
][
'bbox'
][
3
]
for
idx
in
idxes
]
merged_bbox
=
[
min
(
x0s
),
min
(
y0s
),
max
(
x1s
),
max
(
y1s
),
]
ratio
=
0
other_objects
=
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
not
in
(
object_category_id
,
subject_category_id
),
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
for
other_object
in
other_objects
:
ratio
=
max
(
ratio
,
get_overlap_area
(
merged_bbox
,
other_object
[
'bbox'
]
)
*
1.0
/
box_area
(
all_bboxes
[
object_idx
][
'bbox'
])
)
if
ratio
>=
MERGE_BOX_OVERLAP_AREA_RATIO
:
break
return
ratio
def
may_find_other_nearest_bbox
(
subject_idx
,
object_idx
):
ret
=
float
(
'inf'
)
...
...
@@ -299,7 +359,16 @@ class MagicModel:
):
continue
dis
[
i
][
j
]
=
bbox_distance
(
all_bboxes
[
i
][
'bbox'
],
all_bboxes
[
j
][
'bbox'
])
subject_idx
,
object_idx
=
i
,
j
if
all_bboxes
[
j
][
'category_id'
]
==
subject_category_id
:
subject_idx
,
object_idx
=
j
,
i
if
search_overlap_between_boxes
(
subject_idx
,
object_idx
)
>=
MERGE_BOX_OVERLAP_AREA_RATIO
:
dis
[
i
][
j
]
=
float
(
'inf'
)
dis
[
j
][
i
]
=
dis
[
i
][
j
]
continue
dis
[
i
][
j
]
=
self
.
_bbox_distance
(
all_bboxes
[
subject_idx
][
'bbox'
],
all_bboxes
[
object_idx
][
'bbox'
])
dis
[
j
][
i
]
=
dis
[
i
][
j
]
used
=
set
()
...
...
@@ -627,13 +696,13 @@ class MagicModel:
span
[
'type'
]
=
ContentType
.
Image
elif
category_id
==
5
:
# 获取table模型结果
latex
=
layout_det
.
get
(
"
latex
"
,
None
)
html
=
layout_det
.
get
(
"
html
"
,
None
)
latex
=
layout_det
.
get
(
'
latex
'
,
None
)
html
=
layout_det
.
get
(
'
html
'
,
None
)
if
latex
:
span
[
"
latex
"
]
=
latex
span
[
'
latex
'
]
=
latex
elif
html
:
span
[
"
html
"
]
=
html
span
[
"
type
"
]
=
ContentType
.
Table
span
[
'
html
'
]
=
html
span
[
'
type
'
]
=
ContentType
.
Table
elif
category_id
==
13
:
span
[
'content'
]
=
layout_det
[
'latex'
]
span
[
'type'
]
=
ContentType
.
InlineEquation
...
...
magic_pdf/model/pdf_extract_kit.py
View file @
ece7f8d5
...
...
@@ -3,9 +3,11 @@ import os
import
time
from
magic_pdf.libs.Constants
import
*
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.model.model_list
import
AtomicModel
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'YOLO_VERBOSE'
]
=
'False'
# disable yolo logger
try
:
import
cv2
import
yaml
...
...
@@ -32,7 +34,7 @@ except ImportError as e:
exit
(
1
)
from
magic_pdf.model.pek_sub_modules.layoutlmv3.model_init
import
Layoutlmv3_Predictor
from
magic_pdf.model.pek_sub_modules.post_process
import
get_croped_image
,
latex_rm_whitespace
from
magic_pdf.model.pek_sub_modules.post_process
import
latex_rm_whitespace
from
magic_pdf.model.pek_sub_modules.self_modify
import
ModifiedPaddleOCR
from
magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel
import
StructTableModel
from
magic_pdf.model.ppTableModel
import
ppTableModel
...
...
@@ -58,12 +60,13 @@ def mfd_model_init(weight):
def
mfr_model_init
(
weight_dir
,
cfg_path
,
_device_
=
'cpu'
):
args
=
argparse
.
Namespace
(
cfg_path
=
cfg_path
,
options
=
None
)
cfg
=
Config
(
args
)
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.
bin
"
)
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.
pth
"
)
cfg
.
config
.
model
.
model_config
.
model_name
=
weight_dir
cfg
.
config
.
model
.
tokenizer_config
.
path
=
weight_dir
task
=
tasks
.
setup_task
(
cfg
)
model
=
task
.
build_model
(
cfg
)
model
=
model
.
to
(
_device_
)
model
.
to
(
_device_
)
model
.
eval
()
vis_processor
=
load_processor
(
'formula_image_eval'
,
cfg
.
config
.
datasets
.
formula_rec_eval
.
vis_processor
.
eval
)
mfr_transform
=
transforms
.
Compose
([
vis_processor
,
])
return
[
model
,
mfr_transform
]
...
...
@@ -74,8 +77,11 @@ def layout_model_init(weight, config_file, device):
return
model
def
ocr_model_init
(
show_log
:
bool
=
False
,
det_db_box_thresh
=
0.3
):
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
)
def
ocr_model_init
(
show_log
:
bool
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
None
):
if
lang
is
not
None
:
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
lang
=
lang
)
else
:
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
)
return
model
...
...
@@ -134,7 +140,8 @@ def atom_model_init(model_name: str, **kwargs):
elif
model_name
==
AtomicModel
.
OCR
:
atom_model
=
ocr_model_init
(
kwargs
.
get
(
"ocr_show_log"
),
kwargs
.
get
(
"det_db_box_thresh"
)
kwargs
.
get
(
"det_db_box_thresh"
),
kwargs
.
get
(
"lang"
)
)
elif
model_name
==
AtomicModel
.
Table
:
atom_model
=
table_model_init
(
...
...
@@ -150,6 +157,23 @@ def atom_model_init(model_name: str, **kwargs):
return
atom_model
# Unified crop img logic
def
crop_img
(
input_res
,
input_pil_img
,
crop_paste_x
=
0
,
crop_paste_y
=
0
):
crop_xmin
,
crop_ymin
=
int
(
input_res
[
'poly'
][
0
]),
int
(
input_res
[
'poly'
][
1
])
crop_xmax
,
crop_ymax
=
int
(
input_res
[
'poly'
][
4
]),
int
(
input_res
[
'poly'
][
5
])
# Create a white background with an additional width and height of 50
crop_new_width
=
crop_xmax
-
crop_xmin
+
crop_paste_x
*
2
crop_new_height
=
crop_ymax
-
crop_ymin
+
crop_paste_y
*
2
return_image
=
Image
.
new
(
'RGB'
,
(
crop_new_width
,
crop_new_height
),
'white'
)
# Crop image
crop_box
=
(
crop_xmin
,
crop_ymin
,
crop_xmax
,
crop_ymax
)
cropped_img
=
input_pil_img
.
crop
(
crop_box
)
return_image
.
paste
(
cropped_img
,
(
crop_paste_x
,
crop_paste_y
))
return_list
=
[
crop_paste_x
,
crop_paste_y
,
crop_xmin
,
crop_ymin
,
crop_xmax
,
crop_ymax
,
crop_new_width
,
crop_new_height
]
return
return_image
,
return_list
class
CustomPEKModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
...
...
@@ -177,9 +201,10 @@ class CustomPEKModel:
self
.
table_max_time
=
self
.
table_config
.
get
(
"max_time"
,
TABLE_MAX_TIME_VALUE
)
self
.
table_model_type
=
self
.
table_config
.
get
(
"model"
,
TABLE_MASTER
)
self
.
apply_ocr
=
ocr
self
.
lang
=
kwargs
.
get
(
"lang"
,
None
)
logger
.
info
(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}"
.
format
(
self
.
apply_layout
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table:
{}, lang:
{}"
.
format
(
self
.
apply_layout
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
lang
)
)
assert
self
.
apply_layout
,
"DocAnalysis must contain layout model."
...
...
@@ -225,11 +250,13 @@ class CustomPEKModel:
)
# 初始化ocr
if
self
.
apply_ocr
:
# self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
self
.
ocr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
OCR
,
ocr_show_log
=
show_log
,
det_db_box_thresh
=
0.3
det_db_box_thresh
=
0.3
,
lang
=
self
.
lang
)
# init table model
if
self
.
apply_table
:
...
...
@@ -243,10 +270,13 @@ class CustomPEKModel:
table_max_time
=
self
.
table_max_time
,
device
=
self
.
device
)
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
image
):
page_start
=
time
.
time
()
latex_filling_list
=
[]
mf_image_list
=
[]
...
...
@@ -254,11 +284,15 @@ class CustomPEKModel:
layout_start
=
time
.
time
()
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
logger
.
info
(
f
"layout detection cost:
{
layout_cost
}
"
)
logger
.
info
(
f
"layout detection time:
{
layout_cost
}
"
)
pil_img
=
Image
.
fromarray
(
image
)
if
self
.
apply_formula
:
# 公式检测
mfd_start
=
time
.
time
()
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
)[
0
]
logger
.
info
(
f
"mfd time:
{
round
(
time
.
time
()
-
mfd_start
,
2
)
}
"
)
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
new_item
=
{
...
...
@@ -269,7 +303,8 @@ class CustomPEKModel:
}
layout_res
.
append
(
new_item
)
latex_filling_list
.
append
(
new_item
)
bbox_img
=
get_croped_image
(
Image
.
fromarray
(
image
),
[
xmin
,
ymin
,
xmax
,
ymax
])
# bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
mf_image_list
.
append
(
bbox_img
)
# 公式识别
...
...
@@ -279,7 +314,8 @@ class CustomPEKModel:
mfr_res
=
[]
for
mf_img
in
dataloader
:
mf_img
=
mf_img
.
to
(
self
.
device
)
output
=
self
.
mfr_model
.
generate
({
'image'
:
mf_img
})
with
torch
.
no_grad
():
output
=
self
.
mfr_model
.
generate
({
'image'
:
mf_img
})
mfr_res
.
extend
(
output
[
'pred_str'
])
for
res
,
latex
in
zip
(
latex_filling_list
,
mfr_res
):
res
[
'latex'
]
=
latex_rm_whitespace
(
latex
)
...
...
@@ -301,23 +337,14 @@ class CustomPEKModel:
elif
int
(
res
[
'category_id'
])
in
[
5
]:
table_res_list
.
append
(
res
)
# Unified crop img logic
def
crop_img
(
input_res
,
input_pil_img
,
crop_paste_x
=
0
,
crop_paste_y
=
0
):
crop_xmin
,
crop_ymin
=
int
(
input_res
[
'poly'
][
0
]),
int
(
input_res
[
'poly'
][
1
])
crop_xmax
,
crop_ymax
=
int
(
input_res
[
'poly'
][
4
]),
int
(
input_res
[
'poly'
][
5
])
# Create a white background with an additional width and height of 50
crop_new_width
=
crop_xmax
-
crop_xmin
+
crop_paste_x
*
2
crop_new_height
=
crop_ymax
-
crop_ymin
+
crop_paste_y
*
2
return_image
=
Image
.
new
(
'RGB'
,
(
crop_new_width
,
crop_new_height
),
'white'
)
# Crop image
crop_box
=
(
crop_xmin
,
crop_ymin
,
crop_xmax
,
crop_ymax
)
cropped_img
=
input_pil_img
.
crop
(
crop_box
)
return_image
.
paste
(
cropped_img
,
(
crop_paste_x
,
crop_paste_y
))
return_list
=
[
crop_paste_x
,
crop_paste_y
,
crop_xmin
,
crop_ymin
,
crop_xmax
,
crop_ymax
,
crop_new_width
,
crop_new_height
]
return
return_image
,
return_list
pil_img
=
Image
.
fromarray
(
image
)
if
torch
.
cuda
.
is_available
():
properties
=
torch
.
cuda
.
get_device_properties
(
self
.
device
)
total_memory
=
properties
.
total_memory
/
(
1024
**
3
)
# 将字节转换为 GB
if
total_memory
<=
10
:
gc_start
=
time
.
time
()
clean_memory
()
gc_time
=
round
(
time
.
time
()
-
gc_start
,
2
)
logger
.
info
(
f
"gc time:
{
gc_time
}
"
)
# ocr识别
if
self
.
apply_ocr
:
...
...
@@ -367,7 +394,7 @@ class CustomPEKModel:
})
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
logger
.
info
(
f
"ocr
cost
:
{
ocr_cost
}
"
)
logger
.
info
(
f
"ocr
time
:
{
ocr_cost
}
"
)
# 表格识别 table recognition
if
self
.
apply_table
:
...
...
@@ -375,7 +402,7 @@ class CustomPEKModel:
for
res
in
table_res_list
:
new_image
,
_
=
crop_img
(
res
,
pil_img
)
single_table_start_time
=
time
.
time
()
logger
.
info
(
"------------------table recognition processing begins-----------------"
)
#
logger.info("------------------table recognition processing begins-----------------")
latex_code
=
None
html_code
=
None
if
self
.
table_model_type
==
STRUCT_EQTABLE
:
...
...
@@ -383,8 +410,9 @@ class CustomPEKModel:
latex_code
=
self
.
table_model
.
image2latex
(
new_image
)[
0
]
else
:
html_code
=
self
.
table_model
.
img2html
(
new_image
)
run_time
=
time
.
time
()
-
single_table_start_time
logger
.
info
(
f
"------------table recognition processing ends within
{
run_time
}
s-----"
)
#
logger.info(f"------------table recognition processing ends within {run_time}s-----")
if
run_time
>
self
.
table_max_time
:
logger
.
warning
(
f
"------------table recognition processing exceeds max time
{
self
.
table_max_time
}
s----------"
)
# 判断是否返回正常
...
...
@@ -395,12 +423,13 @@ class CustomPEKModel:
if
expected_ending
:
res
[
"latex"
]
=
latex_code
else
:
logger
.
warning
(
f
"
------------
table recognition processing fails
----------
"
)
logger
.
warning
(
f
"table recognition processing fails
, not found expected LaTeX table end
"
)
elif
html_code
:
res
[
"html"
]
=
html_code
else
:
logger
.
warning
(
f
"------------table recognition processing fails----------"
)
table_cost
=
round
(
time
.
time
()
-
table_start
,
2
)
logger
.
info
(
f
"table cost:
{
table_cost
}
"
)
logger
.
warning
(
f
"table recognition processing fails, not get latex or html return"
)
logger
.
info
(
f
"table time:
{
round
(
time
.
time
()
-
table_start
,
2
)
}
"
)
logger
.
info
(
f
"-----page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----"
)
return
layout_res
magic_pdf/model/pp_structure_v2.py
View file @
ece7f8d5
...
...
@@ -18,8 +18,11 @@ def region_to_bbox(region):
class
CustomPaddleModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
):
self
.
model
=
PPStructure
(
table
=
False
,
ocr
=
ocr
,
show_log
=
show_log
)
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
):
if
lang
is
not
None
:
self
.
model
=
PPStructure
(
table
=
False
,
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
)
else
:
self
.
model
=
PPStructure
(
table
=
False
,
ocr
=
ocr
,
show_log
=
show_log
)
def
__call__
(
self
,
img
):
try
:
...
...
projects/web_api/tests
/__init__.py
→
magic_pdf/model/v3
/__init__.py
View file @
ece7f8d5
File moved
magic_pdf/model/v3/helpers.py
0 → 100644
View file @
ece7f8d5
from
collections
import
defaultdict
from
typing
import
List
,
Dict
import
torch
from
transformers
import
LayoutLMv3ForTokenClassification
MAX_LEN
=
510
CLS_TOKEN_ID
=
0
UNK_TOKEN_ID
=
3
EOS_TOKEN_ID
=
2
class
DataCollator
:
def
__call__
(
self
,
features
:
List
[
dict
])
->
Dict
[
str
,
torch
.
Tensor
]:
bbox
=
[]
labels
=
[]
input_ids
=
[]
attention_mask
=
[]
# clip bbox and labels to max length, build input_ids and attention_mask
for
feature
in
features
:
_bbox
=
feature
[
"source_boxes"
]
if
len
(
_bbox
)
>
MAX_LEN
:
_bbox
=
_bbox
[:
MAX_LEN
]
_labels
=
feature
[
"target_index"
]
if
len
(
_labels
)
>
MAX_LEN
:
_labels
=
_labels
[:
MAX_LEN
]
_input_ids
=
[
UNK_TOKEN_ID
]
*
len
(
_bbox
)
_attention_mask
=
[
1
]
*
len
(
_bbox
)
assert
len
(
_bbox
)
==
len
(
_labels
)
==
len
(
_input_ids
)
==
len
(
_attention_mask
)
bbox
.
append
(
_bbox
)
labels
.
append
(
_labels
)
input_ids
.
append
(
_input_ids
)
attention_mask
.
append
(
_attention_mask
)
# add CLS and EOS tokens
for
i
in
range
(
len
(
bbox
)):
bbox
[
i
]
=
[[
0
,
0
,
0
,
0
]]
+
bbox
[
i
]
+
[[
0
,
0
,
0
,
0
]]
labels
[
i
]
=
[
-
100
]
+
labels
[
i
]
+
[
-
100
]
input_ids
[
i
]
=
[
CLS_TOKEN_ID
]
+
input_ids
[
i
]
+
[
EOS_TOKEN_ID
]
attention_mask
[
i
]
=
[
1
]
+
attention_mask
[
i
]
+
[
1
]
# padding to max length
max_len
=
max
(
len
(
x
)
for
x
in
bbox
)
for
i
in
range
(
len
(
bbox
)):
bbox
[
i
]
=
bbox
[
i
]
+
[[
0
,
0
,
0
,
0
]]
*
(
max_len
-
len
(
bbox
[
i
]))
labels
[
i
]
=
labels
[
i
]
+
[
-
100
]
*
(
max_len
-
len
(
labels
[
i
]))
input_ids
[
i
]
=
input_ids
[
i
]
+
[
EOS_TOKEN_ID
]
*
(
max_len
-
len
(
input_ids
[
i
]))
attention_mask
[
i
]
=
attention_mask
[
i
]
+
[
0
]
*
(
max_len
-
len
(
attention_mask
[
i
])
)
ret
=
{
"bbox"
:
torch
.
tensor
(
bbox
),
"attention_mask"
:
torch
.
tensor
(
attention_mask
),
"labels"
:
torch
.
tensor
(
labels
),
"input_ids"
:
torch
.
tensor
(
input_ids
),
}
# set label > MAX_LEN to -100, because original labels may be > MAX_LEN
ret
[
"labels"
][
ret
[
"labels"
]
>
MAX_LEN
]
=
-
100
# set label > 0 to label-1, because original labels are 1-indexed
ret
[
"labels"
][
ret
[
"labels"
]
>
0
]
-=
1
return
ret
def
boxes2inputs
(
boxes
:
List
[
List
[
int
]])
->
Dict
[
str
,
torch
.
Tensor
]:
bbox
=
[[
0
,
0
,
0
,
0
]]
+
boxes
+
[[
0
,
0
,
0
,
0
]]
input_ids
=
[
CLS_TOKEN_ID
]
+
[
UNK_TOKEN_ID
]
*
len
(
boxes
)
+
[
EOS_TOKEN_ID
]
attention_mask
=
[
1
]
+
[
1
]
*
len
(
boxes
)
+
[
1
]
return
{
"bbox"
:
torch
.
tensor
([
bbox
]),
"attention_mask"
:
torch
.
tensor
([
attention_mask
]),
"input_ids"
:
torch
.
tensor
([
input_ids
]),
}
def
prepare_inputs
(
inputs
:
Dict
[
str
,
torch
.
Tensor
],
model
:
LayoutLMv3ForTokenClassification
)
->
Dict
[
str
,
torch
.
Tensor
]:
ret
=
{}
for
k
,
v
in
inputs
.
items
():
v
=
v
.
to
(
model
.
device
)
if
torch
.
is_floating_point
(
v
):
v
=
v
.
to
(
model
.
dtype
)
ret
[
k
]
=
v
return
ret
def
parse_logits
(
logits
:
torch
.
Tensor
,
length
:
int
)
->
List
[
int
]:
"""
parse logits to orders
:param logits: logits from model
:param length: input length
:return: orders
"""
logits
=
logits
[
1
:
length
+
1
,
:
length
]
orders
=
logits
.
argsort
(
descending
=
False
).
tolist
()
ret
=
[
o
.
pop
()
for
o
in
orders
]
while
True
:
order_to_idxes
=
defaultdict
(
list
)
for
idx
,
order
in
enumerate
(
ret
):
order_to_idxes
[
order
].
append
(
idx
)
# filter idxes len > 1
order_to_idxes
=
{
k
:
v
for
k
,
v
in
order_to_idxes
.
items
()
if
len
(
v
)
>
1
}
if
not
order_to_idxes
:
break
# filter
for
order
,
idxes
in
order_to_idxes
.
items
():
# find original logits of idxes
idxes_to_logit
=
{}
for
idx
in
idxes
:
idxes_to_logit
[
idx
]
=
logits
[
idx
,
order
]
idxes_to_logit
=
sorted
(
idxes_to_logit
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
# keep the highest logit as order, set others to next candidate
for
idx
,
_
in
idxes_to_logit
[
1
:]:
ret
[
idx
]
=
orders
[
idx
].
pop
()
return
ret
def
check_duplicate
(
a
:
List
[
int
])
->
bool
:
return
len
(
a
)
!=
len
(
set
(
a
))
magic_pdf/para/para_split_v3.py
0 → 100644
View file @
ece7f8d5
This diff is collapsed.
Click to expand it.
magic_pdf/pdf_parse_by_ocr.py
View file @
ece7f8d5
from
magic_pdf.pdf_parse_union_core
import
pdf_parse_union
from
magic_pdf.pdf_parse_union_core
_v2
import
pdf_parse_union
def
parse_pdf_by_ocr
(
pdf_bytes
,
...
...
magic_pdf/pdf_parse_by_txt.py
View file @
ece7f8d5
from
magic_pdf.pdf_parse_union_core
import
pdf_parse_union
from
magic_pdf.pdf_parse_union_core
_v2
import
pdf_parse_union
def
parse_pdf_by_txt
(
...
...
magic_pdf/pdf_parse_union_core_v2.py
0 → 100644
View file @
ece7f8d5
import
os
import
statistics
import
time
from
loguru
import
logger
from
typing
import
List
import
torch
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.commons
import
fitz
,
get_delta_time
from
magic_pdf.libs.config_reader
import
get_local_layoutreader_model_dir
from
magic_pdf.libs.convert_utils
import
dict_to_list
from
magic_pdf.libs.drop_reason
import
DropReason
from
magic_pdf.libs.hash_utils
import
compute_md5
from
magic_pdf.libs.local_math
import
float_equal
from
magic_pdf.libs.ocr_content_type
import
ContentType
from
magic_pdf.model.magic_model
import
MagicModel
from
magic_pdf.para.para_split_v3
import
para_split
from
magic_pdf.pre_proc.citationmarker_remove
import
remove_citation_marker
from
magic_pdf.pre_proc.construct_page_dict
import
ocr_construct_page_component_v2
from
magic_pdf.pre_proc.cut_image
import
ocr_cut_image_and_table
from
magic_pdf.pre_proc.equations_replace
import
remove_chars_in_text_blocks
,
replace_equations_in_textblock
,
\
combine_chars_to_pymudict
from
magic_pdf.pre_proc.ocr_detect_all_bboxes
import
ocr_prepare_bboxes_for_layout_split_v2
from
magic_pdf.pre_proc.ocr_dict_merge
import
fill_spans_in_blocks
,
fix_block_spans
,
fix_discarded_block
from
magic_pdf.pre_proc.ocr_span_list_modify
import
remove_overlaps_min_spans
,
get_qa_need_list_v2
,
\
remove_overlaps_low_confidence_spans
from
magic_pdf.pre_proc.resolve_bbox_conflict
import
check_useful_block_horizontal_overlap
def remove_horizontal_overlap_block_which_smaller(all_bboxes):
    """Drop the smaller of two horizontally overlapping useful blocks.

    Wraps the first four coordinates of every candidate in the shape expected
    by ``check_useful_block_horizontal_overlap``; when an overlap is reported,
    every entry whose bbox equals the smaller one is removed. ``all_bboxes``
    is mutated in place.

    Returns:
        (bool, list): whether an overlap was found, and the (possibly pruned)
        ``all_bboxes`` list.
    """
    useful_blocks = [{"bbox": entry[:4]} for entry in all_bboxes]
    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = \
        check_useful_block_horizontal_overlap(useful_blocks)
    if is_useful_block_horz_overlap:
        logger.warning(
            f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, "
            f"smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}")
        # Keep only entries whose leading 4 coordinates differ from the loser.
        all_bboxes[:] = [entry for entry in all_bboxes if entry[:4] != smaller_bbox]
    return is_useful_block_horz_overlap, all_bboxes
def __replace_STX_ETX(text_str: str):
    """Replace STX (U+0002) and ETX (U+0003) control characters with "'".

    pymupdf occasionally extracts quotation marks as these two control
    characters; so far this has only been observed in English text, not
    in Chinese text.

    Args:
        text_str (str): raw text

    Returns:
        str: text with both control characters replaced; falsy input is
        returned unchanged.
    """
    if not text_str:
        return text_str
    return text_str.replace('\u0002', "'").replace("\u0003", "'")
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
    """Extract plain-text spans from a pymupdf page.

    Combines the "dict" and "rawdict" extractions into char-level blocks,
    replaces detected equations inside the text, strips citation markers and
    per-char data, then flattens the surviving text spans into the project's
    span dicts (bbox/content/type/score).

    Args:
        pdf_page: a pymupdf page object.
        inline_equations: inline equation blocks detected by the model.
        interline_equations: interline equation blocks detected by the model.

    Returns:
        list[dict]: text spans with a fixed score of 1.0.
    """
    raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
    char_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
    merged_blocks = combine_chars_to_pymudict(raw_blocks, char_blocks)
    merged_blocks = replace_equations_in_textblock(merged_blocks, inline_equations, interline_equations)
    merged_blocks = remove_citation_marker(merged_blocks)
    merged_blocks = remove_chars_in_text_blocks(merged_blocks)

    spans = []
    for blk in merged_blocks:
        for line in blk["lines"]:
            for span in line["spans"]:
                x0, y0, x1, y1 = span["bbox"]
                # Zero-width or zero-height spans carry no visible text.
                if float_equal(x0, x1) or float_equal(y0, y1):
                    continue
                # Equation spans stay on the OCR side; collect text only.
                if span.get('type') in (ContentType.InlineEquation, ContentType.InterlineEquation):
                    continue
                spans.append(
                    {
                        "bbox": list(span["bbox"]),
                        "content": __replace_STX_ETX(span["text"]),
                        "type": ContentType.Text,
                        "score": 1.0,
                    }
                )
    return spans
def replace_text_span(pymu_spans, ocr_spans):
    """Swap OCR text spans for pymupdf text spans.

    Keeps every non-text span from ``ocr_spans`` (equations, images, tables)
    and appends all ``pymu_spans``, which replace the OCR text wholesale.
    """
    non_text_spans = [span for span in ocr_spans if span["type"] != ContentType.Text]
    return non_text_spans + pymu_spans
def model_init(model_name: str):
    """Load and prepare the reading-order model for inference.

    Only ``model_name == "layoutreader"`` is supported; any other name logs an
    error and terminates the process. The model is cast to bfloat16 when the
    CUDA device supports it, moved to the chosen device, and set to eval mode.

    Args:
        model_name (str): must be "layoutreader".

    Returns:
        A ``LayoutLMv3ForTokenClassification`` instance ready for inference.
    """
    from transformers import LayoutLMv3ForTokenClassification
    if torch.cuda.is_available():
        device = torch.device("cuda")
        if torch.cuda.is_bf16_supported():
            supports_bfloat16 = True
        else:
            supports_bfloat16 = False
    else:
        device = torch.device("cpu")
        supports_bfloat16 = False

    if model_name == "layoutreader":
        # Check whether the modelscope cache directory exists; fall back to
        # downloading from the Hugging Face hub when it does not.
        layoutreader_model_dir = get_local_layoutreader_model_dir()
        if os.path.exists(layoutreader_model_dir):
            model = LayoutLMv3ForTokenClassification.from_pretrained(layoutreader_model_dir)
        else:
            logger.warning(f"local layoutreader model not exists, use online model from huggingface")
        # Check whether the device supports bfloat16
            model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
        if supports_bfloat16:
            model.bfloat16()
        model.to(device).eval()
    else:
        # NOTE(review): exit(1) in library code kills the whole process;
        # raising ValueError would be friendlier to callers — confirm intent.
        logger.error("model name not allow")
        exit(1)
    return model
class ModelSingleton:
    """Process-wide cache of initialized models, keyed by model name.

    All instances share the same underlying object (classic singleton via
    ``__new__``) and the same ``_models`` cache, so each model is initialized
    at most once per process.
    """

    _instance = None   # the one shared instance
    _models = {}       # model_name -> initialized model

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(self, model_name: str):
        """Return the cached model for *model_name*, initializing it lazily."""
        try:
            return self._models[model_name]
        except KeyError:
            model = model_init(model_name=model_name)
            self._models[model_name] = model
            return model
def do_predict(boxes: List[List[int]], model) -> List[int]:
    """Run LayoutReader on scaled line boxes and return the reading order.

    Args:
        boxes: line bboxes scaled to the 0..1000 coordinate space.
        model: a prepared LayoutReader model (see ``model_init``).

    Returns:
        list[int]: predicted reading order as indices into ``boxes``.
    """
    from magic_pdf.model.v3.helpers import prepare_inputs, boxes2inputs, parse_logits

    model_inputs = prepare_inputs(boxes2inputs(boxes), model)
    raw_logits = model(**model_inputs).logits
    logits = raw_logits.cpu().squeeze(0)
    return parse_logits(logits, len(boxes))
def cal_block_index(fix_blocks, sorted_bboxes):
    """Assign a reading-order 'index' to every block.

    A block without lines takes the rank of its own bbox inside
    ``sorted_bboxes``; a block with lines ranks each line and takes the
    median of the line ranks. The synthetic 'lines' previously injected
    into table/image blocks are removed once the index is computed.

    Returns:
        the same ``fix_blocks`` list, mutated with 'index' entries.
    """
    for block in fix_blocks:
        lines = block['lines']
        if not lines:
            block['index'] = sorted_bboxes.index(block['bbox'])
        else:
            rank_list = []
            for line in lines:
                line['index'] = sorted_bboxes.index(line['bbox'])
                rank_list.append(line['index'])
            block['index'] = statistics.median(rank_list)

        # Remove the synthetic line info from chart/table blocks.
        if block['type'] in ('table', 'image'):
            del block['lines']
    return fix_blocks
def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
    """Split a block bbox into synthetic line bboxes for reading-order sorting.

    Blocks shorter than about three text lines, and tall narrow blocks, are
    returned whole; otherwise the block is sliced into horizontal strips whose
    count depends on its width relative to the page.

    Args:
        block_bbox: (x0, y0, x1, y1) of the block.
        line_height: median body-text line height on the page.
        page_w, page_h: page dimensions.

    Returns:
        list[list]: one or more [x0, y, x1, y + h] line bboxes.
    """
    x0, y0, x1, y1 = block_bbox
    height = y1 - y0
    width = x1 - x0

    # Shorter than ~3 body lines: keep the block whole.
    if not (line_height * 3 < height):
        return [[x0, y0, x1, y1]]

    if height > page_h * 0.25 and page_w * 0.5 > width > page_w * 0.25:
        # Likely one column of a two-column layout: slice at real line height.
        num_lines = int(height / line_height) + 1
    elif width > page_w * 0.4:
        # Wide block (over 0.4 page width): three strips.
        line_height = (y1 - y0) / 3
        num_lines = 3
    elif width > page_w * 0.25:
        # Medium-width block: two strips.
        line_height = (y1 - y0) / 2
        num_lines = 2
    elif height / width > 1.2:
        # Tall, narrow block: do not split.
        return [[x0, y0, x1, y1]]
    else:
        # Remaining compact blocks: two strips.
        line_height = (y1 - y0) / 2
        num_lines = 2

    # Emit the strips top to bottom, accumulating the y cursor.
    strips = []
    cursor = y0
    for _ in range(num_lines):
        strips.append([x0, cursor, x1, cursor + line_height])
        cursor += line_height
    return strips
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
    """Collect every line bbox on the page and sort them with LayoutReader.

    Text/title/equation blocks contribute their real lines (or synthetic ones
    when empty); table/image blocks always get synthetic lines. Boxes are
    clamped to the page, scaled into LayoutReader's 0..1000 space, and the
    model's predicted order is used to reorder the original bboxes.

    Args:
        fix_blocks: blocks for this page (mutated: synthetic 'lines' added).
        page_w, page_h: page dimensions.
        line_height: median body-text line height (see ``get_line_height``).

    Returns:
        list: the page's line bboxes in predicted reading order.
    """
    page_line_list = []
    for block in fix_blocks:
        if block['type'] in ['text', 'title', 'interline_equation']:
            if len(block['lines']) == 0:
                # No extracted lines: synthesize strips so the block still
                # participates in the ordering.
                bbox = block['bbox']
                lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
                for line in lines:
                    block['lines'].append({'bbox': line, 'spans': []})
                page_line_list.extend(lines)
            else:
                for line in block['lines']:
                    bbox = line['bbox']
                    page_line_list.append(bbox)
        elif block['type'] in ['table', 'image']:
            # Tables/images have no text lines; synthesize strips instead.
            bbox = block['bbox']
            lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
            block['lines'] = []
            for line in lines:
                block['lines'].append({'bbox': line, 'spans': []})
            page_line_list.extend(lines)

    # Sort with LayoutReader: scale coordinates into its 0..1000 space.
    x_scale = 1000.0 / page_w
    y_scale = 1000.0 / page_h
    boxes = []
    # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
    for left, top, right, bottom in page_line_list:
        # Clamp out-of-page coordinates (and log them) before scaling.
        if left < 0:
            logger.warning(
                f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
            left = 0
        if right > page_w:
            logger.warning(
                f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
            right = page_w
        if top < 0:
            logger.warning(
                f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
            top = 0
        if bottom > page_h:
            logger.warning(
                f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
            bottom = page_h

        left = round(left * x_scale)
        top = round(top * y_scale)
        right = round(right * x_scale)
        bottom = round(bottom * y_scale)
        assert (
            1000 >= right >= left >= 0 and 1000 >= bottom >= top >= 0
        ), f"Invalid box. right: {right}, left: {left}, bottom: {bottom}, top: {top}"
        boxes.append([left, top, right, bottom])
    model_manager = ModelSingleton()
    model = model_manager.get_model("layoutreader")
    with torch.no_grad():
        orders = do_predict(boxes, model)
    # Map the predicted permutation back onto the original bboxes.
    sorted_bboxes = [page_line_list[i] for i in orders]

    return sorted_bboxes
def get_line_height(blocks):
    """Return the median line height of body text on the page.

    Considers only text, title and interline-equation blocks; falls back to
    10 when the page has no such lines.
    """
    heights = [
        int(line['bbox'][3] - line['bbox'][1])
        for block in blocks
        if block['type'] in ('text', 'title', 'interline_equation')
        for line in block['lines']
    ]
    return statistics.median(heights) if heights else 10
def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode):
    """Parse one PDF page into the v2 page-info structure.

    Pulls the model's blocks/spans for the page, optionally swaps OCR text
    spans for pymupdf-extracted ones ("txt" mode), de-duplicates spans, crops
    image/table regions, resolves block bbox conflicts, fills spans into
    blocks, sorts everything into reading order and builds the page dict.

    Args:
        pdf_docs: the open pymupdf document.
        magic_model: MagicModel wrapping the per-page model output.
        page_id (int): zero-based page index.
        pdf_bytes_md5 (str): md5 of the whole PDF, used for image paths.
        imageWriter: writer used by ``ocr_cut_image_and_table``.
        parse_mode (str): "txt" or "ocr"; anything else raises.

    Returns:
        dict: page component built by ``ocr_construct_page_component_v2``.
    """
    need_drop = False
    drop_reason = []

    '''Fetch the block info needed later from the magic_model object'''
    img_blocks = magic_model.get_imgs(page_id)
    table_blocks = magic_model.get_tables(page_id)
    discarded_blocks = magic_model.get_discarded(page_id)
    text_blocks = magic_model.get_text_blocks(page_id)
    title_blocks = magic_model.get_title_blocks(page_id)
    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)

    page_w, page_h = magic_model.get_page_size(page_id)

    spans = magic_model.get_all_spans(page_id)

    '''Build spans according to parse_mode'''
    if parse_mode == "txt":
        """Replace OCR text spans with pymupdf spans!"""
        pymu_spans = txt_spans_extract(
            pdf_docs[page_id], inline_equations, interline_equations
        )
        spans = replace_text_span(pymu_spans, spans)
    elif parse_mode == "ocr":
        pass
    else:
        raise Exception("parse_mode must be txt or ocr")

    '''Among overlapping spans, drop the lower-confidence ones'''
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    '''Among overlapping spans, drop the smaller ones'''
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
    '''Crop screenshots for image and table spans'''
    spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)

    '''Gather all block bboxes together'''
    # interline_equation_blocks is not accurate enough; switch to
    # interline_equations below (the list is cleared so the else branch runs).
    interline_equation_blocks = []
    if len(interline_equation_blocks) > 0:
        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
            interline_equation_blocks, page_w, page_h)
    else:
        all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
            img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
            interline_equations, page_w, page_h)

    '''Handle discarded_blocks first — they need no layout ordering'''
    discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)

    '''Skip this page if it has no bboxes at all'''
    if len(all_bboxes) == 0:
        logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
        return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
                                               [], [], interline_equations, fix_discarded_blocks,
                                               need_drop, drop_reason)

    '''Fill spans into blocks'''
    block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.3)

    '''Fix up the blocks'''
    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)

    '''Collect all lines and compute the body-text line height'''
    line_height = get_line_height(fix_blocks)

    '''Collect all lines and sort them'''
    sorted_bboxes = sort_lines_by_model(fix_blocks, page_w, page_h, line_height)

    '''Derive block order from the median of its line ranks'''
    fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)

    '''Reorder the blocks'''
    sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])

    '''Get the lists QA needs exposed separately'''
    images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)

    '''Build the pdf_info_dict entry'''
    page_info = ocr_construct_page_component_v2(sorted_blocks, [], page_id, page_w, page_h,
                                                [], images, tables, interline_equations,
                                                fix_discarded_blocks, need_drop, drop_reason)
    return page_info
def pdf_parse_union(pdf_bytes,
                    model_list,
                    imageWriter,
                    parse_mode,
                    start_page_id=0,
                    end_page_id=None,
                    debug_mode=False,
                    ):
    """Parse a whole PDF (v2 pipeline) into the pdf_info structure.

    Pages inside [start_page_id, end_page_id] go through ``parse_page_core``;
    pages outside the range get an empty "skip page" component so indices
    stay aligned. Paragraph splitting runs over the whole result afterwards.

    Args:
        pdf_bytes (bytes): raw PDF content.
        model_list: per-page model output used to build MagicModel.
        imageWriter: writer for cropped image/table files.
        parse_mode (str): "txt" or "ocr" (see ``parse_page_core``).
        start_page_id (int): first page to actually parse.
        end_page_id (int | None): last page to parse; None or negative
            means "through the last page".
        debug_mode (bool): log per-page timing when True.

    Returns:
        dict: {"pdf_info": [...page components in order...]}.
    """
    pdf_bytes_md5 = compute_md5(pdf_bytes)
    pdf_docs = fitz.open("pdf", pdf_bytes)

    '''Initialize an empty pdf_info_dict'''
    pdf_info_dict = {}

    '''Initialize magic_model from model_list and the docs object'''
    magic_model = MagicModel(model_list, pdf_docs)

    '''Parse the pdf over the requested page range'''
    # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1

    if end_page_id > len(pdf_docs) - 1:
        logger.warning("end_page_id is out of range, use pdf_docs length")
        end_page_id = len(pdf_docs) - 1

    '''Record the start time'''
    start_time = time.time()

    for page_id, page in enumerate(pdf_docs):
        '''In debug mode, log how long each page took'''
        if debug_mode:
            time_now = time.time()
            logger.info(
                f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}")
            start_time = time_now

        '''Parse each page of the pdf'''
        if start_page_id <= page_id <= end_page_id:
            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
        else:
            # Out-of-range pages get an empty, dropped component.
            page_w = page.rect.width
            page_h = page.rect.height
            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h,
                                                        [], [], [], [], [],
                                                        True, "skip page")
        pdf_info_dict[f"page_{page_id}"] = page_info

    """Paragraph splitting"""
    para_split(pdf_info_dict, debug_mode=debug_mode)

    """Convert dict to list"""
    pdf_info_list = dict_to_list(pdf_info_dict)
    new_pdf_info_dict = {
        "pdf_info": pdf_info_list,
    }

    clean_memory()

    return new_pdf_info_dict
# Module is import-only; no standalone CLI behavior is defined.
if __name__ == '__main__':
    pass
magic_pdf/pipe/AbsPipe.py
View file @
ece7f8d5
...
...
@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT
=
"txt"
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
self
.
pdf_bytes
=
pdf_bytes
self
.
model_list
=
model_list
self
.
image_writer
=
image_writer
...
...
@@ -25,6 +25,7 @@ class AbsPipe(ABC):
self
.
is_debug
=
is_debug
self
.
start_page_id
=
start_page_id
self
.
end_page_id
=
end_page_id
self
.
lang
=
lang
def
get_compress_pdf_mid_data
(
self
):
return
JsonCompressor
.
compress_json
(
self
.
pdf_mid_data
)
...
...
@@ -94,7 +95,9 @@ class AbsPipe(ABC):
"""
pdf_mid_data
=
JsonCompressor
.
decompress_json
(
compressed_pdf_mid_data
)
pdf_info_list
=
pdf_mid_data
[
"pdf_info"
]
content_list
=
union_make
(
pdf_info_list
,
MakeMode
.
STANDARD_FORMAT
,
drop_mode
,
img_buket_path
)
parse_type
=
pdf_mid_data
[
"_parse_type"
]
lang
=
pdf_mid_data
.
get
(
"_lang"
,
None
)
content_list
=
union_make
(
pdf_info_list
,
MakeMode
.
STANDARD_FORMAT
,
drop_mode
,
img_buket_path
,
parse_type
,
lang
)
return
content_list
@
staticmethod
...
...
@@ -104,7 +107,9 @@ class AbsPipe(ABC):
"""
pdf_mid_data
=
JsonCompressor
.
decompress_json
(
compressed_pdf_mid_data
)
pdf_info_list
=
pdf_mid_data
[
"pdf_info"
]
md_content
=
union_make
(
pdf_info_list
,
md_make_mode
,
drop_mode
,
img_buket_path
)
parse_type
=
pdf_mid_data
[
"_parse_type"
]
lang
=
pdf_mid_data
.
get
(
"_lang"
,
None
)
md_content
=
union_make
(
pdf_info_list
,
md_make_mode
,
drop_mode
,
img_buket_path
,
parse_type
,
lang
)
return
md_content
magic_pdf/pipe/OCRPipe.py
View file @
ece7f8d5
...
...
@@ -10,19 +10,21 @@ from magic_pdf.user_api import parse_ocr_pdf
class
OCRPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
def
pipe_classify
(
self
):
pass
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/TXTPipe.py
View file @
ece7f8d5
...
...
@@ -11,19 +11,21 @@ from magic_pdf.user_api import parse_txt_pdf
class
TXTPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
def
pipe_classify
(
self
):
pass
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_txt_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/UNIPipe.py
View file @
ece7f8d5
...
...
@@ -14,9 +14,9 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class
UNIPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
jso_useful_key
:
dict
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
self
.
pdf_type
=
jso_useful_key
[
"_pdf_type"
]
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
)
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
if
len
(
self
.
model_list
)
==
0
:
self
.
input_model_is_empty
=
True
else
:
...
...
@@ -28,22 +28,26 @@ class UNIPipe(AbsPipe):
def
pipe_analyze
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
def
pipe_parse
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
pdf_mid_data
=
parse_union_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
input_model_is_empty
=
self
.
input_model_is_empty
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
NONE_WITH_REASON
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
logger
.
info
(
"uni_pipe mk content list finished"
)
return
result
...
...
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
View file @
ece7f8d5
...
...
@@ -60,6 +60,59 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
return
all_bboxes
,
all_discarded_blocks
,
drop_reasons
def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks,
                                           title_blocks, interline_equation_blocks, page_w, page_h):
    """Flatten the model's typed blocks into the v2 bbox-row format and
    resolve overlap conflicts between them.

    Each row is [x0, y0, x1, y1, None*3, BlockType, None*4, score]; the None
    slots are filled by later pipeline stages.

    Returns:
        (list, list): conflict-resolved useful bboxes, and discarded bboxes.
    """
    all_bboxes = []
    all_discarded_blocks = []
    for image in img_blocks:
        x0, y0, x1, y1 = image['bbox']
        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])

    for table in table_blocks:
        x0, y0, x1, y1 = table['bbox']
        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])

    for text in text_blocks:
        x0, y0, x1, y1 = text['bbox']
        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])

    for title in title_blocks:
        x0, y0, x1, y1 = title['bbox']
        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])

    for interline_equation in interline_equation_blocks:
        x0, y0, x1, y1 = interline_equation['bbox']
        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])

    '''Resolve block nesting problems'''
    '''When a text box overlaps a title box, trust the text box'''
    all_bboxes = fix_text_overlap_title_blocks(all_bboxes)
    '''When any box overlaps a discard box, trust the discard box'''
    all_bboxes = remove_need_drop_blocks(all_bboxes, discarded_blocks)

    # interline_equation conflicting with title/text boxes: two cases
    '''When an interline_equation box and a text-type box have IoU close to 1, trust the equation box'''
    all_bboxes = fix_interline_equation_overlap_text_blocks_with_hi_iou(all_bboxes)
    '''When an interline_equation box is contained in a text box and is much smaller, trust the text box and drop the equation box'''
    # Handled later by the big-box-contains-small-box logic

    '''Of discarded_blocks, keep only those wider than 1/3 page width, taller than 10, and in the lower 50% of the page (footnotes)'''
    for discarded in discarded_blocks:
        x0, y0, x1, y1 = discarded['bbox']
        all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])

        # Add footnotes into all_bboxes so they participate in layout
        # if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
        #     all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])

    '''After the steps above, big boxes may still contain small ones; drop the small ones'''
    all_bboxes = remove_overlaps_min_blocks(all_bboxes)
    all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
    '''Separate the remaining bboxes to avoid errors during layout splitting'''
    all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)

    return all_bboxes, all_discarded_blocks
def
fix_interline_equation_overlap_text_blocks_with_hi_iou
(
all_bboxes
):
# 先提取所有text和interline block
text_blocks
=
[]
...
...
magic_pdf/pre_proc/ocr_dict_merge.py
View file @
ece7f8d5
...
...
@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if
__is_overlaps_y_exceeds_threshold
(
span
[
'bbox'
],
current_line
[
-
1
][
'bbox'
]):
if
__is_overlaps_y_exceeds_threshold
(
span
[
'bbox'
],
current_line
[
-
1
][
'bbox'
],
0.6
):
current_line
.
append
(
span
)
else
:
# 否则,开始新行
...
...
magic_pdf/resources/model_config/UniMERNet/demo.yaml
View file @
ece7f8d5
...
...
@@ -2,13 +2,13 @@ model:
arch
:
unimernet
model_type
:
unimernet
model_config
:
model_name
:
./models
max_seq_len
:
1
024
length_aware
:
False
model_name
:
./models
/unimernet_base
max_seq_len
:
1
536
load_pretrained
:
True
pretrained
:
./models/pytorch_model.
bin
pretrained
:
'
./models/
unimernet_base/
pytorch_model.
pth'
tokenizer_config
:
path
:
./models
path
:
./models
/unimernet_base
datasets
:
formula_rec_eval
:
...
...
@@ -18,7 +18,7 @@ datasets:
image_size
:
-
192
-
672
run
:
runner
:
runner_iter
task
:
unimernet_train
...
...
@@ -43,4 +43,4 @@ run:
distributed_type
:
ddp
# or fsdp when train llm
generate_cfg
:
temperature
:
0.0
temperature
:
0.0
\ No newline at end of file
magic_pdf/resources/model_config/model_configs.yaml
View file @
ece7f8d5
...
...
@@ -10,6 +10,6 @@ config:
weights
:
layout
:
Layout/model_final.pth
mfd
:
MFD/weights.pt
mfr
:
MFR/
U
ni
MERNet
mfr
:
MFR/
u
ni
mernet_small
struct_eqtable
:
TabRec/StructEqTable
TableMaster
:
TabRec/TableMaster
\ No newline at end of file
Prev
1
2
3
4
5
6
7
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment