wangsen / MinerU · Commits

Commit bd927919, authored May 27, 2025 by myhloli
refactor: rename init file and update app.py to enable parsing method
parent f5016508
Showing 20 changed files with 0 additions and 2429 deletions.
magic_pdf/libs/convert_utils.py                   +0 -5
magic_pdf/libs/coordinate_transform.py            +0 -9
magic_pdf/libs/draw_bbox.py                       +0 -418
magic_pdf/libs/hash_utils.py                      +0 -15
magic_pdf/libs/json_compressor.py                 +0 -27
magic_pdf/libs/language.py                        +0 -48
magic_pdf/libs/local_math.py                      +0 -9
magic_pdf/libs/markdown_utils.py                  +0 -10
magic_pdf/libs/path_utils.py                      +0 -32
magic_pdf/libs/pdf_check.py                       +0 -99
magic_pdf/libs/pdf_image_tools.py                 +0 -63
magic_pdf/libs/performance_stats.py               +0 -65
magic_pdf/libs/safe_filename.py                   +0 -11
magic_pdf/libs/version.py                         +0 -1
magic_pdf/model/__init__.py                       +0 -2
magic_pdf/model/batch_analyze.py                  +0 -265
magic_pdf/model/doc_analyze_by_custom_model.py    +0 -301
magic_pdf/model/magic_model.py                    +0 -771
magic_pdf/model/model_list.py                     +0 -12
magic_pdf/model/pdf_extract_kit.py                +0 -266
Too many changes to show. To preserve performance, only 150 of 150+ files are displayed.
magic_pdf/libs/convert_utils.py (deleted, 100644 → 0)

```python
def dict_to_list(input_dict):
    items_list = []
    for _, item in input_dict.items():
        items_list.append(item)
    return items_list
```
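For reference, this deleted helper is equivalent to `list(input_dict.values())`; a minimal usage sketch with hypothetical data:

```python
# Minimal usage sketch for the deleted helper (hypothetical data).
d = {'page_0': {'bbox': [0, 0, 10, 10]}, 'page_1': {'bbox': [5, 5, 20, 20]}}
items = dict_to_list(d)  # same result as list(d.values())
assert items == [{'bbox': [0, 0, 10, 10]}, {'bbox': [5, 5, 20, 20]}]
```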
magic_pdf/libs/coordinate_transform.py (deleted, 100644 → 0)

```python
def get_scale_ratio(model_page_info, page):
    pix = page.get_pixmap(dpi=72)
    pymu_width = int(pix.w)
    pymu_height = int(pix.h)
    width_from_json = model_page_info['page_info']['width']
    height_from_json = model_page_info['page_info']['height']
    horizontal_scale_ratio = width_from_json / pymu_width
    vertical_scale_ratio = height_from_json / pymu_height
    return horizontal_scale_ratio, vertical_scale_ratio
```
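The ratios are meant to be divided out of model-space coordinates to recover PyMuPDF page coordinates (this is what `MagicModel.__fix_axis` below does). A sketch, assuming the model ran on a 2x render of the page:

```python
# Hypothetical page_info: the model ran on a 2x render of a 612x792-pt page,
# so get_scale_ratio(...) would return (2.0, 2.0) for that page.
h_ratio, v_ratio = 2.0, 2.0
model_bbox = [100, 200, 300, 400]
# Map a model-space bbox back to page space by dividing each coordinate:
page_bbox = [model_bbox[0] / h_ratio, model_bbox[1] / v_ratio,
             model_bbox[2] / h_ratio, model_bbox[3] / v_ratio]
assert page_bbox == [50.0, 100.0, 150.0, 200.0]
```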
magic_pdf/libs/draw_bbox.py (deleted, 100644 → 0)

```python
import fitz

from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
                                               ContentType)
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.magic_model import MagicModel


def draw_bbox_without_number(i, bbox_list, page, rgb_config, fill_config):
    new_rgb = []
    for item in rgb_config:
        item = float(item) / 255
        new_rgb.append(item)
    page_data = bbox_list[i]
    for bbox in page_data:
        x0, y0, x1, y1 = bbox
        rect_coords = fitz.Rect(x0, y0, x1, y1)  # Define the rectangle
        if fill_config:
            page.draw_rect(
                rect_coords,
                color=None,
                fill=new_rgb,
                fill_opacity=0.3,
                width=0.5,
                overlay=True,
            )  # Draw the rectangle
        else:
            page.draw_rect(
                rect_coords,
                color=new_rgb,
                fill=None,
                fill_opacity=1,
                width=0.5,
                overlay=True,
            )  # Draw the rectangle


def draw_bbox_with_number(i, bbox_list, page, rgb_config, fill_config, draw_bbox=True):
    new_rgb = []
    for item in rgb_config:
        item = float(item) / 255
        new_rgb.append(item)
    page_data = bbox_list[i]
    for j, bbox in enumerate(page_data):
        x0, y0, x1, y1 = bbox
        rect_coords = fitz.Rect(x0, y0, x1, y1)  # Define the rectangle
        if draw_bbox:
            if fill_config:
                page.draw_rect(
                    rect_coords,
                    color=None,
                    fill=new_rgb,
                    fill_opacity=0.3,
                    width=0.5,
                    overlay=True,
                )  # Draw the rectangle
            else:
                page.draw_rect(
                    rect_coords,
                    color=new_rgb,
                    fill=None,
                    fill_opacity=1,
                    width=0.5,
                    overlay=True,
                )  # Draw the rectangle
        page.insert_text(
            (x1 + 2, y0 + 10), str(j + 1), fontsize=10, color=new_rgb
        )  # Insert the index number beside the top-right corner of the rectangle


def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
    dropped_bbox_list = []
    tables_list, tables_body_list = [], []
    tables_caption_list, tables_footnote_list = [], []
    imgs_list, imgs_body_list, imgs_caption_list = [], [], []
    imgs_footnote_list = []
    titles_list = []
    texts_list = []
    interequations_list = []
    lists_list = []
    indexs_list = []
    for page in pdf_info:
        page_dropped_list = []
        tables, tables_body, tables_caption, tables_footnote = [], [], [], []
        imgs, imgs_body, imgs_caption, imgs_footnote = [], [], [], []
        titles = []
        texts = []
        interequations = []
        lists = []
        indices = []
        for dropped_bbox in page['discarded_blocks']:
            page_dropped_list.append(dropped_bbox['bbox'])
        dropped_bbox_list.append(page_dropped_list)
        for block in page['para_blocks']:
            bbox = block['bbox']
            if block['type'] == BlockType.Table:
                tables.append(bbox)
                for nested_block in block['blocks']:
                    bbox = nested_block['bbox']
                    if nested_block['type'] == BlockType.TableBody:
                        tables_body.append(bbox)
                    elif nested_block['type'] == BlockType.TableCaption:
                        tables_caption.append(bbox)
                    elif nested_block['type'] == BlockType.TableFootnote:
                        tables_footnote.append(bbox)
            elif block['type'] == BlockType.Image:
                imgs.append(bbox)
                for nested_block in block['blocks']:
                    bbox = nested_block['bbox']
                    if nested_block['type'] == BlockType.ImageBody:
                        imgs_body.append(bbox)
                    elif nested_block['type'] == BlockType.ImageCaption:
                        imgs_caption.append(bbox)
                    elif nested_block['type'] == BlockType.ImageFootnote:
                        imgs_footnote.append(bbox)
            elif block['type'] == BlockType.Title:
                titles.append(bbox)
            elif block['type'] == BlockType.Text:
                texts.append(bbox)
            elif block['type'] == BlockType.InterlineEquation:
                interequations.append(bbox)
            elif block['type'] == BlockType.List:
                lists.append(bbox)
            elif block['type'] == BlockType.Index:
                indices.append(bbox)
        tables_list.append(tables)
        tables_body_list.append(tables_body)
        tables_caption_list.append(tables_caption)
        tables_footnote_list.append(tables_footnote)
        imgs_list.append(imgs)
        imgs_body_list.append(imgs_body)
        imgs_caption_list.append(imgs_caption)
        imgs_footnote_list.append(imgs_footnote)
        titles_list.append(titles)
        texts_list.append(texts)
        interequations_list.append(interequations)
        lists_list.append(lists)
        indexs_list.append(indices)

    layout_bbox_list = []
    table_type_order = {'table_caption': 1, 'table_body': 2, 'table_footnote': 3}
    for page in pdf_info:
        page_block_list = []
        for block in page['para_blocks']:
            if block['type'] in [
                BlockType.Text,
                BlockType.Title,
                BlockType.InterlineEquation,
                BlockType.List,
                BlockType.Index,
            ]:
                bbox = block['bbox']
                page_block_list.append(bbox)
            elif block['type'] in [BlockType.Image]:
                for sub_block in block['blocks']:
                    bbox = sub_block['bbox']
                    page_block_list.append(bbox)
            elif block['type'] in [BlockType.Table]:
                sorted_blocks = sorted(block['blocks'], key=lambda x: table_type_order[x['type']])
                for sub_block in sorted_blocks:
                    bbox = sub_block['bbox']
                    page_block_list.append(bbox)
        layout_bbox_list.append(page_block_list)

    pdf_docs = fitz.open('pdf', pdf_bytes)
    for i, page in enumerate(pdf_docs):
        draw_bbox_without_number(i, dropped_bbox_list, page, [158, 158, 158], True)
        # draw_bbox_without_number(i, tables_list, page, [153, 153, 0], True)  # color !
        draw_bbox_without_number(i, tables_body_list, page, [204, 204, 0], True)
        draw_bbox_without_number(i, tables_caption_list, page, [255, 255, 102], True)
        draw_bbox_without_number(i, tables_footnote_list, page, [229, 255, 204], True)
        # draw_bbox_without_number(i, imgs_list, page, [51, 102, 0], True)
        draw_bbox_without_number(i, imgs_body_list, page, [153, 255, 51], True)
        draw_bbox_without_number(i, imgs_caption_list, page, [102, 178, 255], True)
        draw_bbox_without_number(i, imgs_footnote_list, page, [255, 178, 102], True)
        draw_bbox_without_number(i, titles_list, page, [102, 102, 255], True)
        draw_bbox_without_number(i, texts_list, page, [153, 0, 76], True)
        draw_bbox_without_number(i, interequations_list, page, [0, 255, 0], True)
        draw_bbox_without_number(i, lists_list, page, [40, 169, 92], True)
        draw_bbox_without_number(i, indexs_list, page, [40, 169, 92], True)
        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False, draw_bbox=False)

    # Save the PDF
    pdf_docs.save(f'{out_path}/{filename}')


def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
    text_list = []
    inline_equation_list = []
    interline_equation_list = []
    image_list = []
    table_list = []
    dropped_list = []
    next_page_text_list = []
    next_page_inline_equation_list = []

    def get_span_info(span):
        if span['type'] == ContentType.Text:
            if span.get(CROSS_PAGE, False):
                next_page_text_list.append(span['bbox'])
            else:
                page_text_list.append(span['bbox'])
        elif span['type'] == ContentType.InlineEquation:
            if span.get(CROSS_PAGE, False):
                next_page_inline_equation_list.append(span['bbox'])
            else:
                page_inline_equation_list.append(span['bbox'])
        elif span['type'] == ContentType.InterlineEquation:
            page_interline_equation_list.append(span['bbox'])
        elif span['type'] == ContentType.Image:
            page_image_list.append(span['bbox'])
        elif span['type'] == ContentType.Table:
            page_table_list.append(span['bbox'])

    for page in pdf_info:
        page_text_list = []
        page_inline_equation_list = []
        page_interline_equation_list = []
        page_image_list = []
        page_table_list = []
        page_dropped_list = []

        # Move cross-page spans into the next page's lists
        if len(next_page_text_list) > 0:
            page_text_list.extend(next_page_text_list)
            next_page_text_list.clear()
        if len(next_page_inline_equation_list) > 0:
            page_inline_equation_list.extend(next_page_inline_equation_list)
            next_page_inline_equation_list.clear()

        # Build dropped_list
        for block in page['discarded_blocks']:
            if block['type'] == BlockType.Discarded:
                for line in block['lines']:
                    for span in line['spans']:
                        page_dropped_list.append(span['bbox'])
        dropped_list.append(page_dropped_list)

        # Build the remaining useful lists
        # for block in page['para_blocks']:  # spans can use the result from before paragraph merging directly
        for block in page['preproc_blocks']:
            if block['type'] in [
                BlockType.Text,
                BlockType.Title,
                BlockType.InterlineEquation,
                BlockType.List,
                BlockType.Index,
            ]:
                for line in block['lines']:
                    for span in line['spans']:
                        get_span_info(span)
            elif block['type'] in [BlockType.Image, BlockType.Table]:
                for sub_block in block['blocks']:
                    for line in sub_block['lines']:
                        for span in line['spans']:
                            get_span_info(span)
        text_list.append(page_text_list)
        inline_equation_list.append(page_inline_equation_list)
        interline_equation_list.append(page_interline_equation_list)
        image_list.append(page_image_list)
        table_list.append(page_table_list)

    pdf_docs = fitz.open('pdf', pdf_bytes)
    for i, page in enumerate(pdf_docs):
        # Draw the current page's data
        draw_bbox_without_number(i, text_list, page, [255, 0, 0], False)
        draw_bbox_without_number(i, inline_equation_list, page, [0, 255, 0], False)
        draw_bbox_without_number(i, interline_equation_list, page, [0, 0, 255], False)
        draw_bbox_without_number(i, image_list, page, [255, 204, 0], False)
        draw_bbox_without_number(i, table_list, page, [204, 0, 255], False)
        draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)

    # Save the PDF
    pdf_docs.save(f'{out_path}/{filename}')


def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
    dropped_bbox_list = []
    tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
    imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
    titles_list = []
    texts_list = []
    interequations_list = []
    magic_model = MagicModel(model_list, dataset)
    for i in range(len(model_list)):
        page_dropped_list = []
        tables_body, tables_caption, tables_footnote = [], [], []
        imgs_body, imgs_caption, imgs_footnote = [], [], []
        titles = []
        texts = []
        interequations = []
        page_info = magic_model.get_model_list(i)
        layout_dets = page_info['layout_dets']
        for layout_det in layout_dets:
            bbox = layout_det['bbox']
            if layout_det['category_id'] == CategoryId.Text:
                texts.append(bbox)
            elif layout_det['category_id'] == CategoryId.Title:
                titles.append(bbox)
            elif layout_det['category_id'] == CategoryId.TableBody:
                tables_body.append(bbox)
            elif layout_det['category_id'] == CategoryId.TableCaption:
                tables_caption.append(bbox)
            elif layout_det['category_id'] == CategoryId.TableFootnote:
                tables_footnote.append(bbox)
            elif layout_det['category_id'] == CategoryId.ImageBody:
                imgs_body.append(bbox)
            elif layout_det['category_id'] == CategoryId.ImageCaption:
                imgs_caption.append(bbox)
            elif layout_det['category_id'] == CategoryId.InterlineEquation_YOLO:
                interequations.append(bbox)
            elif layout_det['category_id'] == CategoryId.Abandon:
                page_dropped_list.append(bbox)
            elif layout_det['category_id'] == CategoryId.ImageFootnote:
                imgs_footnote.append(bbox)
        tables_body_list.append(tables_body)
        tables_caption_list.append(tables_caption)
        tables_footnote_list.append(tables_footnote)
        imgs_body_list.append(imgs_body)
        imgs_caption_list.append(imgs_caption)
        titles_list.append(titles)
        texts_list.append(texts)
        interequations_list.append(interequations)
        dropped_bbox_list.append(page_dropped_list)
        imgs_footnote_list.append(imgs_footnote)

    for i in range(len(dataset)):
        page = dataset.get_page(i)
        draw_bbox_with_number(i, dropped_bbox_list, page, [158, 158, 158], True)  # color !
        draw_bbox_with_number(i, tables_body_list, page, [204, 204, 0], True)
        draw_bbox_with_number(i, tables_caption_list, page, [255, 255, 102], True)
        draw_bbox_with_number(i, tables_footnote_list, page, [229, 255, 204], True)
        draw_bbox_with_number(i, imgs_body_list, page, [153, 255, 51], True)
        draw_bbox_with_number(i, imgs_caption_list, page, [102, 178, 255], True)
        draw_bbox_with_number(i, imgs_footnote_list, page, [255, 178, 102], True)
        draw_bbox_with_number(i, titles_list, page, [102, 102, 255], True)
        draw_bbox_with_number(i, texts_list, page, [153, 0, 76], True)
        draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)

    # Save the PDF
    dataset.dump_to_file(f'{out_path}/{filename}')


def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
    layout_bbox_list = []
    for page in pdf_info:
        page_line_list = []
        for block in page['preproc_blocks']:
            if block['type'] in [BlockType.Text]:
                for line in block['lines']:
                    bbox = line['bbox']
                    index = line['index']
                    page_line_list.append({'index': index, 'bbox': bbox})
            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
                if 'virtual_lines' in block:
                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
                        for line in block['virtual_lines']:
                            bbox = line['bbox']
                            index = line['index']
                            page_line_list.append({'index': index, 'bbox': bbox})
                else:
                    for line in block['lines']:
                        bbox = line['bbox']
                        index = line['index']
                        page_line_list.append({'index': index, 'bbox': bbox})
            elif block['type'] in [BlockType.Image, BlockType.Table]:
                for sub_block in block['blocks']:
                    if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                        if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
                            for line in sub_block['virtual_lines']:
                                bbox = line['bbox']
                                index = line['index']
                                page_line_list.append({'index': index, 'bbox': bbox})
                        else:
                            for line in sub_block['lines']:
                                bbox = line['bbox']
                                index = line['index']
                                page_line_list.append({'index': index, 'bbox': bbox})
                    elif sub_block['type'] in [
                        BlockType.ImageCaption,
                        BlockType.TableCaption,
                        BlockType.ImageFootnote,
                        BlockType.TableFootnote,
                    ]:
                        for line in sub_block['lines']:
                            bbox = line['bbox']
                            index = line['index']
                            page_line_list.append({'index': index, 'bbox': bbox})
        sorted_bboxes = sorted(page_line_list, key=lambda x: x['index'])
        layout_bbox_list.append([sorted_bbox['bbox'] for sorted_bbox in sorted_bboxes])
    pdf_docs = fitz.open('pdf', pdf_bytes)
    for i, page in enumerate(pdf_docs):
        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)
    pdf_docs.save(f'{out_path}/{filename}')


def draw_char_bbox(pdf_bytes, out_path, filename):
    pdf_docs = fitz.open('pdf', pdf_bytes)
    for i, page in enumerate(pdf_docs):
        for block in page.get_text(
            'rawdict',
            flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP,
        )['blocks']:
            for line in block['lines']:
                for span in line['spans']:
                    for char in span['chars']:
                        char_bbox = char['bbox']
                        page.draw_rect(
                            char_bbox,
                            color=[1, 0, 0],
                            fill=None,
                            fill_opacity=1,
                            width=0.3,
                            overlay=True,
                        )
    pdf_docs.save(f'{out_path}/{filename}')
```
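A minimal driver for the char-level visualizer, assuming a local `demo.pdf` (hypothetical file name); the other `draw_*` functions additionally need the pipeline's `pdf_info`/`model_list` structures:

```python
# Hypothetical input file; draw_char_bbox only needs raw PDF bytes.
with open('demo.pdf', 'rb') as f:
    pdf_bytes = f.read()

# Writes /tmp/demo_char_bbox.pdf with a red 0.3pt rectangle around every character.
draw_char_bbox(pdf_bytes, '/tmp', 'demo_char_bbox.pdf')
```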
magic_pdf/libs/hash_utils.py (deleted, 100644 → 0)

```python
import hashlib


def compute_md5(file_bytes):
    hasher = hashlib.md5()
    hasher.update(file_bytes)
    return hasher.hexdigest().upper()


def compute_sha256(input_string):
    hasher = hashlib.sha256()
    # In Python 3, strings must be encoded to bytes before the hash function can process them
    input_bytes = input_string.encode('utf-8')
    hasher.update(input_bytes)
    return hasher.hexdigest()
```
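A quick sanity check; note that `compute_md5` uppercases its digest while `compute_sha256` does not (digest values are the standard MD5/SHA-256 of "hello"):

```python
assert compute_md5(b'hello') == '5D41402ABC4B2A76B9719D911017C592'
assert compute_sha256('hello') == (
    '2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824'
)
```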
magic_pdf/libs/json_compressor.py (deleted, 100644 → 0)

```python
import json
import brotli
import base64


class JsonCompressor:

    @staticmethod
    def compress_json(data):
        """Compress a json object and encode it with base64"""
        json_str = json.dumps(data)
        json_bytes = json_str.encode('utf-8')
        compressed = brotli.compress(json_bytes, quality=6)
        compressed_str = base64.b64encode(compressed).decode('utf-8')  # convert bytes to string
        return compressed_str

    @staticmethod
    def decompress_json(compressed_str):
        """Decode the base64 string and decompress the json object"""
        compressed = base64.b64decode(compressed_str.encode('utf-8'))  # convert string to bytes
        decompressed_bytes = brotli.decompress(compressed)
        json_str = decompressed_bytes.decode('utf-8')
        data = json.loads(json_str)
        return data
```
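Round-trip sketch (requires the `brotli` package; data is hypothetical):

```python
data = {'layout_dets': [{'bbox': [0, 0, 10, 10], 'score': 0.98}]}
packed = JsonCompressor.compress_json(data)  # base64 text, safe to embed in JSON or a DB column
assert JsonCompressor.decompress_json(packed) == data
```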
magic_pdf/libs/language.py (deleted, 100644 → 0)

```python
import os
import unicodedata

if not os.getenv("FTLANG_CACHE"):
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    root_dir = os.path.dirname(current_dir)
    ftlang_cache_dir = os.path.join(root_dir, 'resources', 'fasttext-langdetect')
    os.environ["FTLANG_CACHE"] = str(ftlang_cache_dir)
    # print(os.getenv("FTLANG_CACHE"))

from fast_langdetect import detect_language


def remove_invalid_surrogates(text):
    # Remove invalid UTF-16 surrogate pairs
    return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))


def detect_lang(text: str) -> str:
    if len(text) == 0:
        return ""
    text = text.replace("\n", "")
    text = remove_invalid_surrogates(text)
    # print(text)
    try:
        lang_upper = detect_language(text)
    except:
        html_no_ctrl_chars = ''.join([l for l in text if unicodedata.category(l)[0] not in ['C', ]])
        lang_upper = detect_language(html_no_ctrl_chars)
    try:
        lang = lang_upper.lower()
    except:
        lang = ""
    return lang


if __name__ == '__main__':
    print(os.getenv("FTLANG_CACHE"))
    print(detect_lang("This is a test."))
    print(detect_lang("<html>This is a test</html>"))
    print(detect_lang("这个是中文测试。"))
    print(detect_lang("<html>这个是中文测试。</html>"))
    print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
```
magic_pdf/libs/local_math.py (deleted, 100644 → 0)

```python
def float_gt(a, b):
    if 0.0001 >= abs(a - b):
        return False
    return a > b


def float_equal(a, b):
    if 0.0001 >= abs(a - b):
        return True
    return False
```
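Both comparisons treat differences of at most 1e-4 as equality, so:

```python
assert float_equal(1.0, 1.00005)   # within the 1e-4 tolerance
assert not float_gt(1.00005, 1.0)  # "greater" only counts beyond the tolerance
assert float_gt(1.001, 1.0)
```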
magic_pdf/libs/markdown_utils.py (deleted, 100644 → 0)

```python
def ocr_escape_special_markdown_char(content):
    """Escape characters in the body text that carry special meaning in Markdown syntax."""
    special_chars = ["*", "`", "~", "$"]
    for char in special_chars:
        content = content.replace(char, "\\" + char)
    return content
```
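For example:

```python
assert ocr_escape_special_markdown_char('cost: $5 * n') == 'cost: \\$5 \\* n'
```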
magic_pdf/libs/path_utils.py (deleted, 100644 → 0)

```python
def remove_non_official_s3_args(s3path):
    """
    example: s3://abc/xxxx.json?bytes=0,81350 ==> s3://abc/xxxx.json
    """
    arr = s3path.split("?")
    return arr[0]


def parse_s3path(s3path: str):
    # from s3pathlib import S3Path
    # p = S3Path(remove_non_official_s3_args(s3path))
    # return p.bucket, p.key
    s3path = remove_non_official_s3_args(s3path).strip()
    if s3path.startswith(('s3://', 's3a://')):
        prefix, path = s3path.split('://', 1)
        bucket_name, key = path.split('/', 1)
        return bucket_name, key
    elif s3path.startswith('/'):
        raise ValueError("The provided path starts with '/'. This does not conform to a valid S3 path format.")
    else:
        raise ValueError("Invalid S3 path format. Expected 's3://bucket-name/key' or 's3a://bucket-name/key'.")


def parse_s3_range_params(s3path: str):
    """
    example: s3://abc/xxxx.json?bytes=0,81350 ==> [0, 81350]
    """
    arr = s3path.split("?bytes=")
    if len(arr) == 1:
        return None
    return arr[1].split(",")
```
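Behavior in practice; note that `parse_s3_range_params` returns the range values as strings, not the ints its docstring suggests:

```python
assert remove_non_official_s3_args('s3://abc/xxxx.json?bytes=0,81350') == 's3://abc/xxxx.json'
assert parse_s3path('s3://abc/xxxx.json?bytes=0,81350') == ('abc', 'xxxx.json')
assert parse_s3_range_params('s3://abc/xxxx.json?bytes=0,81350') == ['0', '81350']
```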
magic_pdf/libs/pdf_check.py (deleted, 100644 → 0)

```python
import fitz
import numpy as np
from loguru import logger
import re
from io import BytesIO
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams


def calculate_sample_count(total_page: int):
    """Compute the number of pages to sample from the total page count and the sampling rate."""
    select_page_cnt = min(10, total_page)
    return select_page_cnt


def extract_pages(src_pdf_bytes: bytes) -> fitz.Document:
    pdf_docs = fitz.open("pdf", src_pdf_bytes)
    total_page = len(pdf_docs)
    if total_page == 0:
        # If the PDF has no pages, return an empty document right away
        logger.warning("PDF is empty, return empty document")
        return fitz.Document()
    select_page_cnt = calculate_sample_count(total_page)
    page_num = np.random.choice(total_page, select_page_cnt, replace=False)
    sample_docs = fitz.Document()
    try:
        for index in page_num:
            sample_docs.insert_pdf(pdf_docs, from_page=int(index), to_page=int(index))
    except Exception as e:
        logger.exception(e)
    return sample_docs


def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
    """Detect whether the PDF contains invalid (garbled) characters."""
    # pdfminer is slow, so first sample roughly 10 random pages
    sample_docs = extract_pages(src_pdf_bytes)
    sample_pdf_bytes = sample_docs.tobytes()
    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
    laparams = LAParams(
        line_overlap=0.5,
        char_margin=2.0,
        line_margin=0.5,
        word_margin=0.1,
        boxes_flow=None,
        detect_vertical=False,
        all_texts=False,
    )
    text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
    text = text.replace("\n", "")
    # logger.info(text)
    # Garbled text extracted by pdfminer shows the signature (cid:xxx)
    cid_pattern = re.compile(r'\(cid:\d+\)')
    matches = cid_pattern.findall(text)
    cid_count = len(matches)
    cid_len = sum(len(match) for match in matches)
    text_len = len(text)
    if text_len == 0:
        cid_chars_radio = 0
    else:
        cid_chars_radio = cid_count / (cid_count + text_len - cid_len)
    logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
    # If more than 5% of the text is garbled, treat the document as garbled
    if cid_chars_radio > 0.05:
        return False  # garbled document
    else:
        return True  # normal document


def count_replacement_characters(text: str) -> int:
    """Count the number of 0xfffd (replacement) characters in the string."""
    return text.count('\ufffd')


def detect_invalid_chars_by_pymupdf(src_pdf_bytes: bytes) -> bool:
    sample_docs = extract_pages(src_pdf_bytes)
    doc_text = ""
    for page in sample_docs:
        page_text = page.get_text('text', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)
        doc_text += page_text
    text_len = len(doc_text)
    uffd_count = count_replacement_characters(doc_text)
    if text_len == 0:
        uffd_chars_radio = 0
    else:
        uffd_chars_radio = uffd_count / text_len
    logger.info(f"uffd_count: {uffd_count}, text_len: {text_len}, uffd_chars_radio: {uffd_chars_radio}")
    # If more than 1% of the text is garbled, treat the document as garbled
    if uffd_chars_radio > 0.01:
        return False  # garbled document
    else:
        return True  # normal document
```
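Note the inverted naming: both detectors return True for a normal document and False for a garbled one. A minimal driver, assuming a local `demo.pdf` (hypothetical file name):

```python
with open('demo.pdf', 'rb') as f:  # hypothetical input file
    pdf_bytes = f.read()

if not detect_invalid_chars_by_pymupdf(pdf_bytes):
    print('text layer looks garbled; OCR is probably the safer parse path')
```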
magic_pdf/libs/pdf_image_tools.py (deleted, 100644 → 0)

```python
from io import BytesIO
import cv2
import fitz
import numpy as np
from PIL import Image
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.libs.commons import join_path
from magic_pdf.libs.hash_utils import compute_sha256


def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
    """Crop a JPG image out of `page` (page number `page_num`) according to `bbox` and return the image path.
    save_path must support both s3 and local storage; the image is stored under save_path with filename
    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg, with the bbox values truncated to integers."""
    # Build the file name
    filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
    # The old version returned the path without the bucket
    img_path = join_path(return_path, filename) if return_path is not None else None
    # The new version generates a flattened path
    img_hash256_path = f'{compute_sha256(img_path)}.jpg'
    # Convert the coordinates to a fitz.Rect object
    rect = fitz.Rect(*bbox)
    # Set the zoom factor to 3x
    zoom = fitz.Matrix(3, 3)
    # Crop the image
    pix = page.get_pixmap(clip=rect, matrix=zoom)
    byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
    imageWriter.write(img_hash256_path, byte_data)
    return img_hash256_path


def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
    # Convert the coordinates to a fitz.Rect object
    rect = fitz.Rect(*bbox)
    # Set the zoom factor to 3x
    zoom = fitz.Matrix(3, 3)
    # Crop the image
    pix = page.get_pixmap(clip=rect, matrix=zoom)
    if mode == "cv2":
        # Convert directly to a numpy array for cv2
        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
        # PyMuPDF uses RGB order while cv2 uses BGR order
        if pix.n == 3 or pix.n == 4:
            image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
        else:
            image_result = img_array
    elif mode == "pillow":
        # Wrap the byte data in a file-like object
        image_file = BytesIO(pix.tobytes(output='png'))
        # Open the image with Pillow
        image_result = Image.open(image_file)
    else:
        raise ValueError(f"mode: {mode} is not supported.")
    return image_result
```
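A sketch of the PIL path, assuming a local `demo.pdf` (hypothetical file name); the bbox is given in page coordinates and the crop is rendered at the 3x zoom hard-coded above:

```python
import fitz

doc = fitz.open('demo.pdf')  # hypothetical input file
page = doc[0]
# Crop a page-coordinate region at 3x zoom and get a PIL image back.
img = cut_image_to_pil_image((72, 72, 300, 200), page, mode='pillow')
img.save('crop.png')
```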
magic_pdf/libs/performance_stats.py (deleted, 100644 → 0)

```python
import time
import functools
from collections import defaultdict
from typing import Dict, List


class PerformanceStats:
    """Performance-statistics class for collecting and reporting method execution times."""

    _stats: Dict[str, List[float]] = defaultdict(list)

    @classmethod
    def add_execution_time(cls, func_name: str, execution_time: float):
        """Record one execution time."""
        cls._stats[func_name].append(execution_time)

    @classmethod
    def get_stats(cls) -> Dict[str, dict]:
        """Return the aggregated statistics."""
        results = {}
        for func_name, times in cls._stats.items():
            results[func_name] = {
                'count': len(times),
                'total_time': sum(times),
                'avg_time': sum(times) / len(times),
                'min_time': min(times),
                'max_time': max(times)
            }
        return results

    @classmethod
    def print_stats(cls):
        """Print the aggregated statistics."""
        stats = cls.get_stats()
        print("\nPerformance statistics:")
        print("-" * 80)
        print(f"{'method':<40} {'calls':>8} {'total(s)':>12} {'avg(s)':>12}")
        print("-" * 80)
        for func_name, data in stats.items():
            print(f"{func_name:<40} {data['count']:8d} {data['total_time']:12.6f} {data['avg_time']:12.6f}")


def measure_time(func):
    """Decorator that measures a method's execution time."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        execution_time = time.time() - start_time
        # Build a more detailed identifier for the function
        if hasattr(func, "__self__"):
            # instance method
            class_name = func.__self__.__class__.__name__
            full_name = f"{class_name}.{func.__name__}"
        elif hasattr(func, "__qualname__"):
            # class method or static method
            full_name = func.__qualname__
        else:
            module_name = func.__module__
            full_name = f"{module_name}.{func.__name__}"
        PerformanceStats.add_execution_time(full_name, execution_time)
        return result
    return wrapper
```
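Typical use of the decorator (function name and sleep duration are hypothetical):

```python
import time

@measure_time
def slow_square(x):
    time.sleep(0.01)  # stand-in for real work
    return x * x

for n in range(3):
    slow_square(n)

PerformanceStats.print_stats()  # one row: 3 calls plus total/avg times
```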
magic_pdf/libs/safe_filename.py (deleted, 100644 → 0)

```python
import os


def sanitize_filename(filename, replacement="_"):
    if os.name == 'nt':
        invalid_chars = '<>:"|?*'
        for char in invalid_chars:
            filename = filename.replace(char, replacement)
    return filename
```
magic_pdf/libs/version.py (deleted, 100644 → 0)

```python
__version__ = "1.3.12"
```
magic_pdf/model/__init__.py (deleted, 100644 → 0)

```python
__use_inside_model__ = True
__model_mode__ = 'full'
```
magic_pdf/model/batch_analyze.py (deleted, 100644 → 0)

```python
import time

import cv2
from loguru import logger
from tqdm import tqdm

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res, get_coords_and_area)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)

YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16


class BatchAnalyze:
    def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
        self.model_manager = model_manager
        self.batch_ratio = batch_ratio
        self.show_log = show_log
        self.layout_model = layout_model
        self.formula_enable = formula_enable
        self.table_enable = table_enable

    def __call__(self, images_with_extra_info: list) -> list:
        if len(images_with_extra_info) == 0:
            return []

        images_layout_res = []
        layout_start_time = time.time()
        self.model = self.model_manager.get_model(
            ocr=True,
            show_log=self.show_log,
            lang=None,
            layout_model=self.layout_model,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

        images = [image for image, _, _ in images_with_extra_info]

        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            for image in images:
                layout_res = self.model.layout_model(image, ignore_catids=[])
                images_layout_res.append(layout_res)
        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            layout_images = []
            for image_index, image in enumerate(images):
                layout_images.append(image)
            images_layout_res += self.model.layout_model.batch_predict(
                # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
                layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
            )
        # logger.info(
        #     f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
        # )

        if self.model.apply_formula:
            # formula detection (MFD)
            mfd_start_time = time.time()
            images_mfd_res = self.model.mfd_model.batch_predict(
                # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
                images, MFD_BASE_BATCH_SIZE
            )
            # logger.info(
            #     f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
            # )

            # formula recognition (MFR)
            mfr_start_time = time.time()
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            mfr_count = 0
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]
                mfr_count += len(images_formula_list[image_index])
            # logger.info(
            #     f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
            # )

        # free VRAM
        # clean_vram(self.model.device, vram_threshold=8)

        ocr_res_list_all_page = []
        table_res_list_all_page = []
        for index in range(len(images)):
            _, ocr_enable, _lang = images_with_extra_info[index]
            layout_res = images_layout_res[index]
            np_array_img = images[index]

            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )

            ocr_res_list_all_page.append({
                'ocr_res_list': ocr_res_list,
                'lang': _lang,
                'ocr_enable': ocr_enable,
                'np_array_img': np_array_img,
                'single_page_mfdetrec_res': single_page_mfdetrec_res,
                'layout_res': layout_res,
            })

            for table_res in table_res_list:
                table_img, _ = crop_img(table_res, np_array_img)
                table_res_list_all_page.append({
                    'table_res': table_res,
                    'lang': _lang,
                    'table_img': table_img,
                })

        # text-box detection (OCR-det)
        det_start = time.time()
        det_count = 0
        # for ocr_res_list_dict in ocr_res_list_all_page:
        for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
            # Process each area that requires OCR processing
            _lang = ocr_res_list_dict['lang']
            # Get OCR results for this language's images
            atom_model_manager = AtomModelSingleton()
            ocr_model = atom_model_manager.get_atom_model(
                atom_model_name='ocr',
                ocr_show_log=False,
                det_db_box_thresh=0.3,
                lang=_lang
            )
            for res in ocr_res_list_dict['ocr_res_list']:
                new_image, useful_list = crop_img(
                    res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                )
                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                    ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                )

                # OCR-det
                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                ocr_res = ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]

                # Integration results
                if ocr_res:
                    ocr_result_list = get_ocr_result_list(
                        ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
                    )
                    if res["category_id"] == 3:
                        # Total area of all bboxes in ocr_result_list
                        ocr_res_area = sum(
                            get_coords_and_area(ocr_res_item)[4]
                            for ocr_res_item in ocr_result_list
                            if 'poly' in ocr_res_item
                        )
                        # Ratio of ocr_res_area to the area of res
                        res_area = get_coords_and_area(res)[4]
                        if res_area > 0:
                            ratio = ocr_res_area / res_area
                            if ratio > 0.25:
                                res["category_id"] = 1
                            else:
                                continue
                    ocr_res_list_dict['layout_res'].extend(ocr_result_list)
            # det_count += len(ocr_res_list_dict['ocr_res_list'])
        # logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')

        # table recognition
        if self.model.apply_table:
            table_start = time.time()
            # for table_res_list_dict in table_res_list_all_page:
            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
                _lang = table_res_dict['lang']
                atom_model_manager = AtomModelSingleton()
                table_model = atom_model_manager.get_atom_model(
                    atom_model_name='table',
                    table_model_name='rapid_table',
                    table_model_path='',
                    table_max_time=400,
                    device='cpu',
                    lang=_lang,
                    table_sub_model_name='slanet_plus'
                )
                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(
                    table_res_dict['table_img']
                )
                # Check whether the model returned a sane result
                if html_code:
                    expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
                    if expected_ending:
                        table_res_dict['table_res']['html'] = html_code
                    else:
                        logger.warning('table recognition processing fails, not found expected HTML table end')
                else:
                    logger.warning('table recognition processing fails, not get html return')
            # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')

        # Create dictionaries to store items by language
        need_ocr_lists_by_lang = {}  # Dict of lists for each language
        img_crop_lists_by_lang = {}  # Dict of lists for each language

        for layout_res in images_layout_res:
            for layout_res_item in layout_res:
                if layout_res_item['category_id'] in [15]:
                    if 'np_img' in layout_res_item and 'lang' in layout_res_item:
                        lang = layout_res_item['lang']
                        # Initialize lists for this language if not exist
                        if lang not in need_ocr_lists_by_lang:
                            need_ocr_lists_by_lang[lang] = []
                            img_crop_lists_by_lang[lang] = []
                        # Add to the appropriate language-specific lists
                        need_ocr_lists_by_lang[lang].append(layout_res_item)
                        img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
                        # Remove the fields after adding to lists
                        layout_res_item.pop('np_img')
                        layout_res_item.pop('lang')

        if len(img_crop_lists_by_lang) > 0:
            # Process OCR by language
            rec_time = 0
            rec_start = time.time()
            total_processed = 0

            # Process each language separately
            for lang, img_crop_list in img_crop_lists_by_lang.items():
                if len(img_crop_list) > 0:
                    # Get OCR results for this language's images
                    atom_model_manager = AtomModelSingleton()
                    ocr_model = atom_model_manager.get_atom_model(
                        atom_model_name='ocr',
                        ocr_show_log=False,
                        det_db_box_thresh=0.3,
                        lang=lang
                    )
                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]

                    # Verify we have matching counts
                    assert len(ocr_res_list) == len(need_ocr_lists_by_lang[lang]), \
                        f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'

                    # Process OCR results for this language
                    for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
                        ocr_text, ocr_score = ocr_res_list[index]
                        layout_res_item['text'] = ocr_text
                        layout_res_item['score'] = float(f"{ocr_score:.3f}")
                    total_processed += len(img_crop_list)
            rec_time += time.time() - rec_start
            # logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')

        return images_layout_res
```
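For orientation, each element of `images_with_extra_info` is a 3-tuple of page image, OCR flag, and language, as the unpacking above implies; a sketch with hypothetical values (the real tuples are built in `doc_analyze` below):

```python
import numpy as np

# (page image as an RGB ndarray, ocr_enable flag, language code)
images_with_extra_info = [
    (np.zeros((1684, 1191, 3), dtype=np.uint8), True, 'en'),
]
```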
magic_pdf/model/doc_analyze_by_custom_model.py (deleted, 100644 → 0)

```python
import os
import time

import numpy as np
import torch

os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle's JIT compilation
os.environ['FLAGS_use_stride_kernel'] = '0'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let MPS fall back to CPU
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates

from loguru import logger

from magic_pdf.model.sub_modules.model_utils import get_vram
from magic_pdf.config.enums import SupportedPdfParseMethod
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL


class ModelSingleton:
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        ocr: bool,
        show_log: bool,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(
                ocr=ocr,
                show_log=show_log,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[key]


def custom_model_init(
    ocr: bool = False,
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    model = None
    if model_config.__model_mode__ == 'lite':
        logger.warning('The Lite mode is provided for developers to conduct testing only, and the output quality is '
                       'not guaranteed to be reliable.')
        model = MODEL.Paddle
    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if model_config.__use_inside_model__:
        model_init_start = time.time()
        if model == MODEL.Paddle:
            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
        elif model == MODEL.PEK:
            from magic_pdf.model.pdf_extract_kit import CustomPEKModel
            # Read model-dir and device from the config file
            local_models_dir = get_local_models_dir()
            device = get_device()

            layout_config = get_layout_config()
            if layout_model is not None:
                layout_config['model'] = layout_model

            formula_config = get_formula_config()
            if formula_enable is not None:
                formula_config['enable'] = formula_enable

            table_config = get_table_recog_config()
            if table_enable is not None:
                table_config['enable'] = table_enable

            model_input = {
                'ocr': ocr,
                'show_log': show_log,
                'models_dir': local_models_dir,
                'device': device,
                'table_config': table_config,
                'layout_config': layout_config,
                'formula_config': formula_config,
                'lang': lang,
            }

            custom_model = CustomPEKModel(**model_input)
        else:
            logger.error('Not allow model_name!')
            exit(1)
        model_init_cost = time.time() - model_init_start
        logger.info(f'model init cost: {model_init_cost}')
    else:
        logger.error('use_inside_model is False, not allow to use inside model')
        exit(1)

    return custom_model


def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(dataset) - 1
    )

    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
    images = []
    page_wh_list = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict['img'])
            page_wh_list.append((img_dict['width'], img_dict['height']))
    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(images))]

    if len(images) >= MIN_BATCH_INFERENCE_SIZE:
        batch_size = MIN_BATCH_INFERENCE_SIZE
        batch_images = [images_with_extra_info[i:i + batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
    else:
        batch_images = [images_with_extra_info]

    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
        result = may_batch_image_analyze(batch_image, ocr, show_log, layout_model, formula_enable, table_enable)
        results.extend(result)

    model_json = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
        else:
            result = []
            page_height = 0
            page_width = 0

        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    from magic_pdf.operators.models import InferenceResult
    return InferenceResult(model_json, dataset)


def batch_doc_analyze(
    datasets: list[Dataset],
    parse_method: str = 'auto',
    show_log: bool = False,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
    batch_size = MIN_BATCH_INFERENCE_SIZE
    page_wh_list = []

    images_with_extra_info = []
    for dataset in datasets:
        ocr = False
        if parse_method == 'auto':
            if dataset.classify() == SupportedPdfParseMethod.TXT:
                ocr = False
            elif dataset.classify() == SupportedPdfParseMethod.OCR:
                ocr = True
        elif parse_method == 'ocr':
            ocr = True
        elif parse_method == 'txt':
            ocr = False

        _lang = dataset._lang

        for index in range(len(dataset)):
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            page_wh_list.append((img_dict['width'], img_dict['height']))
            images_with_extra_info.append((img_dict['img'], ocr, _lang))

    batch_images = [images_with_extra_info[i:i + batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
        result = may_batch_image_analyze(batch_image, True, show_log, layout_model, formula_enable, table_enable)
        results.extend(result)

    infer_results = []
    from magic_pdf.operators.models import InferenceResult
    for index in range(len(datasets)):
        dataset = datasets[index]
        model_json = []
        for i in range(len(dataset)):
            result = results.pop(0)
            page_width, page_height = page_wh_list.pop(0)
            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
            page_dict = {'layout_dets': result, 'page_info': page_info}
            model_json.append(page_dict)
        infer_results.append(InferenceResult(model_json, dataset))
    return infer_results


def may_batch_image_analyze(
    images_with_extra_info: list[(np.ndarray, bool, str)],
    ocr: bool,
    show_log: bool = False,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
):
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
    from magic_pdf.model.batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()

    # images = [image for image, _, _ in images_with_extra_info]
    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            torch.npu.set_compile_mode(jit_compile=False)

    if str(device).startswith('npu') or str(device).startswith('cuda'):
        vram = get_vram(device)
        if vram is not None:
            gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            # Default batch_ratio when VRAM can't be determined
            batch_ratio = 1
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')

    # doc_analyze_start = time.time()
    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)

    # gc_start = time.time()
    clean_memory(get_device())
    # gc_time = round(time.time() - gc_start, 2)
    # logger.debug(f'gc time: {gc_time}')

    # doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    # doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
    # logger.debug(
    #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
    #     f' speed: {doc_analyze_speed} pages/second'
    # )
    return results
```
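A sketch of a typical call, assuming magic_pdf's PymuDocDataset is the Dataset implementation fed to this module (hypothetical file name; class name as I understand this release's data API, so treat it as an assumption):

```python
# Hypothetical driver: wrap raw PDF bytes in a Dataset and run the analysis.
from magic_pdf.data.dataset import PymuDocDataset  # assumed Dataset implementation

with open('demo.pdf', 'rb') as f:  # hypothetical input file
    ds = PymuDocDataset(f.read())

infer_result = doc_analyze(ds, ocr=True, lang='en')  # returns an InferenceResult
```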
magic_pdf/model/magic_model.py
deleted
100644 → 0
View file @
f5016508
import
enum
from
magic_pdf.config.model_block_type
import
ModelBlockTypeEnum
from
magic_pdf.config.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.boxbase
import
(
_is_in
,
bbox_distance
,
bbox_relative_pos
,
calculate_iou
)
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO
=
0.6
MERGE_BOX_OVERLAP_AREA_RATIO
=
1.1
class
PosRelationEnum
(
enum
.
Enum
):
LEFT
=
'left'
RIGHT
=
'right'
UP
=
'up'
BOTTOM
=
'bottom'
ALL
=
'all'
class
MagicModel
:
"""每个函数没有得到元素的时候返回空list."""
def
__fix_axis
(
self
):
for
model_page_info
in
self
.
__model_list
:
need_remove_list
=
[]
page_no
=
model_page_info
[
'page_info'
][
'page_no'
]
horizontal_scale_ratio
,
vertical_scale_ratio
=
get_scale_ratio
(
model_page_info
,
self
.
__docs
.
get_page
(
page_no
)
)
layout_dets
=
model_page_info
[
'layout_dets'
]
for
layout_det
in
layout_dets
:
if
layout_det
.
get
(
'bbox'
)
is
not
None
:
# 兼容直接输出bbox的模型数据,如paddle
x0
,
y0
,
x1
,
y1
=
layout_det
[
'bbox'
]
else
:
# 兼容直接输出poly的模型数据,如xxx
x0
,
y0
,
_
,
_
,
x1
,
y1
,
_
,
_
=
layout_det
[
'poly'
]
bbox
=
[
int
(
x0
/
horizontal_scale_ratio
),
int
(
y0
/
vertical_scale_ratio
),
int
(
x1
/
horizontal_scale_ratio
),
int
(
y1
/
vertical_scale_ratio
),
]
layout_det
[
'bbox'
]
=
bbox
# 删除高度或者宽度小于等于0的spans
if
bbox
[
2
]
-
bbox
[
0
]
<=
0
or
bbox
[
3
]
-
bbox
[
1
]
<=
0
:
need_remove_list
.
append
(
layout_det
)
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_low_confidence
(
self
):
for
model_page_info
in
self
.
__model_list
:
need_remove_list
=
[]
layout_dets
=
model_page_info
[
'layout_dets'
]
for
layout_det
in
layout_dets
:
if
layout_det
[
'score'
]
<=
0.05
:
need_remove_list
.
append
(
layout_det
)
else
:
continue
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_high_iou_and_low_confidence
(
self
):
for
model_page_info
in
self
.
__model_list
:
need_remove_list
=
[]
layout_dets
=
model_page_info
[
'layout_dets'
]
for
layout_det1
in
layout_dets
:
for
layout_det2
in
layout_dets
:
if
layout_det1
==
layout_det2
:
continue
if
layout_det1
[
'category_id'
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
]
and
layout_det2
[
'category_id'
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]:
if
(
calculate_iou
(
layout_det1
[
'bbox'
],
layout_det2
[
'bbox'
])
>
0.9
):
if
layout_det1
[
'score'
]
<
layout_det2
[
'score'
]:
layout_det_need_remove
=
layout_det1
else
:
layout_det_need_remove
=
layout_det2
if
layout_det_need_remove
not
in
need_remove_list
:
need_remove_list
.
append
(
layout_det_need_remove
)
else
:
continue
else
:
continue
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__init__
(
self
,
model_list
:
list
,
docs
:
Dataset
):
self
.
__model_list
=
model_list
self
.
__docs
=
docs
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
self
.
__fix_axis
()
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self
.
__fix_by_remove_low_confidence
()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self
.
__fix_by_remove_high_iou_and_low_confidence
()
self
.
__fix_footnote
()
def
_bbox_distance
(
self
,
bbox1
,
bbox2
):
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
flags
=
[
left
,
right
,
bottom
,
top
]
count
=
sum
([
1
if
v
else
0
for
v
in
flags
])
if
count
>
1
:
return
float
(
'inf'
)
if
left
or
right
:
l1
=
bbox1
[
3
]
-
bbox1
[
1
]
l2
=
bbox2
[
3
]
-
bbox2
[
1
]
else
:
l1
=
bbox1
[
2
]
-
bbox1
[
0
]
l2
=
bbox2
[
2
]
-
bbox2
[
0
]
if
l2
>
l1
and
(
l2
-
l1
)
/
l1
>
0.3
:
return
float
(
'inf'
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
__fix_footnote
(
self
):
# 3: figure, 5: table, 7: footnote
for
model_page_info
in
self
.
__model_list
:
footnotes
=
[]
figures
=
[]
tables
=
[]
for
obj
in
model_page_info
[
'layout_dets'
]:
if
obj
[
'category_id'
]
==
7
:
footnotes
.
append
(
obj
)
elif
obj
[
'category_id'
]
==
3
:
figures
.
append
(
obj
)
elif
obj
[
'category_id'
]
==
5
:
tables
.
append
(
obj
)
if
len
(
footnotes
)
*
len
(
figures
)
==
0
:
continue
dis_figure_footnote
=
{}
dis_table_footnote
=
{}
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
figures
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
figures
[
j
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
dis_figure_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
tables
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
tables
[
j
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
dis_table_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
if
i
not
in
dis_figure_footnote
:
continue
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
def
__reduct_overlap
(
self
,
bboxes
):
N
=
len
(
bboxes
)
keep
=
[
True
]
*
N
for
i
in
range
(
N
):
for
j
in
range
(
N
):
if
i
==
j
:
continue
if
_is_in
(
bboxes
[
i
][
'bbox'
],
bboxes
[
j
][
'bbox'
]):
keep
[
i
]
=
False
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
    def __tie_up_category_by_distance_v2(
        self,
        page_no: int,
        subject_category_id: int,
        object_category_id: int,
        priority_pos: PosRelationEnum,
    ):
        """Attach object blocks (e.g. captions) to the subject blocks
        (e.g. image or table bodies) they belong to, by bounding-box distance.

        For every object, the nearest subject in each of the four directions
        is recorded; the final owner is chosen by comparing the horizontal
        and vertical candidates, with `priority_pos` breaking near-ties
        between top and bottom.

        Args:
            page_no (int): page number of the page to process
            subject_category_id (int): category id of the subject blocks
            object_category_id (int): category id of the object blocks
            priority_pos (PosRelationEnum): preferred position of the objects
                relative to the subjects

        Returns:
            list: one record per subject, with 'sub_bbox', 'obj_bboxes'
                and 'sub_idx' keys
        """
        AXIS_MULPLICITY = 0.5
        subjects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == subject_category_id,
                        self.__model_list[page_no]['layout_dets'],
                    ),
                )
            )
        )

        objects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == object_category_id,
                        self.__model_list[page_no]['layout_dets'],
                    ),
                )
            )
        )

        M = len(objects)
        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

        sub_obj_map_h = {i: [] for i in range(len(subjects))}

        dis_by_directions = {
            'top': [[-1, float('inf')]] * M,
            'bottom': [[-1, float('inf')]] * M,
            'left': [[-1, float('inf')]] * M,
            'right': [[-1, float('inf')]] * M,
        }

        for i, obj in enumerate(objects):
            l_x_axis, l_y_axis = (
                obj['bbox'][2] - obj['bbox'][0],
                obj['bbox'][3] - obj['bbox'][1],
            )
            axis_unit = min(l_x_axis, l_y_axis)
            for j, sub in enumerate(subjects):
                bbox1, bbox2, _ = _remove_overlap_between_bbox(
                    objects[i]['bbox'], subjects[j]['bbox']
                )
                left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
                flags = [left, right, bottom, top]
                if sum([1 if v else 0 for v in flags]) > 1:
                    continue
                if left:
                    if dis_by_directions['left'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['left'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]
                if right:
                    if dis_by_directions['right'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['right'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]
                if bottom:
                    if dis_by_directions['bottom'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['bottom'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]
                if top:
                    if dis_by_directions['top'][i][1] > bbox_distance(
                        obj['bbox'], sub['bbox']
                    ):
                        dis_by_directions['top'][i] = [
                            j,
                            bbox_distance(obj['bbox'], sub['bbox']),
                        ]

            if (
                dis_by_directions['top'][i][1] != float('inf')
                and dis_by_directions['bottom'][i][1] != float('inf')
                and priority_pos in (PosRelationEnum.BOTTOM, PosRelationEnum.UP)
            ):
                RATIO = 3
                if (
                    abs(
                        dis_by_directions['top'][i][1]
                        - dis_by_directions['bottom'][i][1]
                    )
                    < RATIO * axis_unit
                ):
                    if priority_pos == PosRelationEnum.BOTTOM:
                        sub_obj_map_h[dis_by_directions['bottom'][i][0]].append(i)
                    else:
                        sub_obj_map_h[dis_by_directions['top'][i][0]].append(i)
                    continue

            if dis_by_directions['left'][i][1] != float('inf') or dis_by_directions[
                'right'
            ][i][1] != float('inf'):
                if dis_by_directions['left'][i][1] != float(
                    'inf'
                ) and dis_by_directions['right'][i][1] != float('inf'):
                    if AXIS_MULPLICITY * axis_unit >= abs(
                        dis_by_directions['left'][i][1]
                        - dis_by_directions['right'][i][1]
                    ):
                        left_sub_bbox = subjects[dis_by_directions['left'][i][0]][
                            'bbox'
                        ]
                        right_sub_bbox = subjects[dis_by_directions['right'][i][0]][
                            'bbox'
                        ]

                        left_sub_bbox_y_axis = left_sub_bbox[3] - left_sub_bbox[1]
                        right_sub_bbox_y_axis = right_sub_bbox[3] - right_sub_bbox[1]

                        if (
                            abs(left_sub_bbox_y_axis - l_y_axis)
                            + dis_by_directions['left'][i][0]
                            > abs(right_sub_bbox_y_axis - l_y_axis)
                            + dis_by_directions['right'][i][0]
                        ):
                            left_or_right = dis_by_directions['right'][i]
                        else:
                            left_or_right = dis_by_directions['left'][i]
                    else:
                        left_or_right = dis_by_directions['left'][i]
                        if left_or_right[1] > dis_by_directions['right'][i][1]:
                            left_or_right = dis_by_directions['right'][i]
                else:
                    left_or_right = dis_by_directions['left'][i]
                    if left_or_right[1] == float('inf'):
                        left_or_right = dis_by_directions['right'][i]
            else:
                left_or_right = [-1, float('inf')]

            if dis_by_directions['top'][i][1] != float('inf') or dis_by_directions[
                'bottom'
            ][i][1] != float('inf'):
                if dis_by_directions['top'][i][1] != float(
                    'inf'
                ) and dis_by_directions['bottom'][i][1] != float('inf'):
                    if AXIS_MULPLICITY * axis_unit >= abs(
                        dis_by_directions['top'][i][1]
                        - dis_by_directions['bottom'][i][1]
                    ):
                        top_bottom = subjects[dis_by_directions['bottom'][i][0]][
                            'bbox'
                        ]
                        bottom_top = subjects[dis_by_directions['top'][i][0]]['bbox']

                        top_bottom_x_axis = top_bottom[2] - top_bottom[0]
                        bottom_top_x_axis = bottom_top[2] - bottom_top[0]
                        if (
                            abs(top_bottom_x_axis - l_x_axis)
                            + dis_by_directions['bottom'][i][1]
                            > abs(bottom_top_x_axis - l_x_axis)
                            + dis_by_directions['top'][i][1]
                        ):
                            top_or_bottom = dis_by_directions['top'][i]
                        else:
                            top_or_bottom = dis_by_directions['bottom'][i]
                    else:
                        top_or_bottom = dis_by_directions['top'][i]
                        if top_or_bottom[1] > dis_by_directions['bottom'][i][1]:
                            top_or_bottom = dis_by_directions['bottom'][i]
                else:
                    top_or_bottom = dis_by_directions['top'][i]
                    if top_or_bottom[1] == float('inf'):
                        top_or_bottom = dis_by_directions['bottom'][i]
            else:
                top_or_bottom = [-1, float('inf')]

            if left_or_right[1] != float('inf') or top_or_bottom[1] != float('inf'):
                if left_or_right[1] != float('inf') and top_or_bottom[1] != float(
                    'inf'
                ):
                    if AXIS_MULPLICITY * axis_unit >= abs(
                        left_or_right[1] - top_or_bottom[1]
                    ):
                        y_axis_bbox = subjects[left_or_right[0]]['bbox']
                        x_axis_bbox = subjects[top_or_bottom[0]]['bbox']

                        if (
                            abs((x_axis_bbox[2] - x_axis_bbox[0]) - l_x_axis) / l_x_axis
                            > abs((y_axis_bbox[3] - y_axis_bbox[1]) - l_y_axis)
                            / l_y_axis
                        ):
                            sub_obj_map_h[left_or_right[0]].append(i)
                        else:
                            sub_obj_map_h[top_or_bottom[0]].append(i)
                    else:
                        if left_or_right[1] > top_or_bottom[1]:
                            sub_obj_map_h[top_or_bottom[0]].append(i)
                        else:
                            sub_obj_map_h[left_or_right[0]].append(i)
                else:
                    if left_or_right[1] != float('inf'):
                        sub_obj_map_h[left_or_right[0]].append(i)
                    else:
                        sub_obj_map_h[top_or_bottom[0]].append(i)

        ret = []
        for i in sub_obj_map_h.keys():
            ret.append(
                {
                    'sub_bbox': {
                        'bbox': subjects[i]['bbox'],
                        'score': subjects[i]['score'],
                    },
                    'obj_bboxes': [
                        {'score': objects[j]['score'], 'bbox': objects[j]['bbox']}
                        for j in sub_obj_map_h[i]
                    ],
                    'sub_idx': i,
                }
            )
        return ret
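The v2 strategy above keys everything off two geometric primitives: bbox_relative_pos, which reports on which single side an object sits relative to a subject, and bbox_distance, which measures the gap between two rectangles. Below is a minimal, self-contained sketch of the distance primitive for intuition; the body is an illustrative re-implementation, not the magic_pdf library code:

# Illustrative sketch only -- not part of the original file.
# Bboxes are (x0, y0, x1, y1); the distance is 0 when rectangles touch or overlap.
import math

def rect_distance(b1, b2):
    # shortest axis-aligned gap between the two rectangles
    dx = max(b2[0] - b1[2], b1[0] - b2[2], 0)
    dy = max(b2[1] - b1[3], b1[1] - b2[3], 0)
    return math.hypot(dx, dy)

table = (100, 50, 300, 200)
caption = (100, 210, 300, 230)        # sits 10pt below the table body
print(rect_distance(table, caption))  # 10.0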
    def __tie_up_category_by_distance_v3(
        self,
        page_no: int,
        subject_category_id: int,
        object_category_id: int,
        priority_pos: PosRelationEnum,
    ):
        subjects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == subject_category_id,
                        self.__model_list[page_no]['layout_dets'],
                    ),
                )
            )
        )
        objects = self.__reduct_overlap(
            list(
                map(
                    lambda x: {'bbox': x['bbox'], 'score': x['score']},
                    filter(
                        lambda x: x['category_id'] == object_category_id,
                        self.__model_list[page_no]['layout_dets'],
                    ),
                )
            )
        )

        ret = []
        N, M = len(subjects), len(objects)
        subjects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)
        objects.sort(key=lambda x: x['bbox'][0] ** 2 + x['bbox'][1] ** 2)

        OBJ_IDX_OFFSET = 10000
        SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

        all_boxes_with_idx = [
            (i, SUB_BIT_KIND, sub['bbox'][0], sub['bbox'][1])
            for i, sub in enumerate(subjects)
        ] + [
            (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj['bbox'][0], obj['bbox'][1])
            for i, obj in enumerate(objects)
        ]
        seen_idx = set()
        seen_sub_idx = set()

        while N > len(seen_sub_idx):
            candidates = []
            for idx, kind, x0, y0 in all_boxes_with_idx:
                if idx in seen_idx:
                    continue
                candidates.append((idx, kind, x0, y0))
            if len(candidates) == 0:
                break
            left_x = min([v[2] for v in candidates])
            top_y = min([v[3] for v in candidates])

            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)

            fst_idx, fst_kind, left_x, top_y = candidates[0]
            candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
            nxt = None
            for i in range(1, len(candidates)):
                if candidates[i][1] ^ fst_kind == 1:
                    nxt = candidates[i]
                    break
            if nxt is None:
                break

            if fst_kind == SUB_BIT_KIND:
                sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
            else:
                sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

            pair_dis = bbox_distance(
                subjects[sub_idx]['bbox'], objects[obj_idx]['bbox']
            )
            nearest_dis = float('inf')
            for i in range(N):
                if i in seen_idx or i == sub_idx:
                    continue
                nearest_dis = min(
                    nearest_dis,
                    bbox_distance(subjects[i]['bbox'], objects[obj_idx]['bbox']),
                )

            if pair_dis >= 3 * nearest_dis:
                seen_idx.add(sub_idx)
                continue

            seen_idx.add(sub_idx)
            seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
            seen_sub_idx.add(sub_idx)

            ret.append(
                {
                    'sub_bbox': {
                        'bbox': subjects[sub_idx]['bbox'],
                        'score': subjects[sub_idx]['score'],
                    },
                    'obj_bboxes': [
                        {
                            'score': objects[obj_idx]['score'],
                            'bbox': objects[obj_idx]['bbox'],
                        }
                    ],
                    'sub_idx': sub_idx,
                }
            )

        for i in range(len(objects)):
            j = i + OBJ_IDX_OFFSET
            if j in seen_idx:
                continue
            seen_idx.add(j)
            nearest_dis, nearest_sub_idx = float('inf'), -1
            for k in range(len(subjects)):
                dis = bbox_distance(objects[i]['bbox'], subjects[k]['bbox'])
                if dis < nearest_dis:
                    nearest_dis = dis
                    nearest_sub_idx = k

            for k in range(len(subjects)):
                if k != nearest_sub_idx:
                    continue
                if k in seen_sub_idx:
                    for kk in range(len(ret)):
                        if ret[kk]['sub_idx'] == k:
                            ret[kk]['obj_bboxes'].append(
                                {'score': objects[i]['score'], 'bbox': objects[i]['bbox']}
                            )
                            break
                else:
                    ret.append(
                        {
                            'sub_bbox': {
                                'bbox': subjects[k]['bbox'],
                                'score': subjects[k]['score'],
                            },
                            'obj_bboxes': [
                                {
                                    'score': objects[i]['score'],
                                    'bbox': objects[i]['bbox'],
                                }
                            ],
                            'sub_idx': k,
                        }
                    )
                seen_sub_idx.add(k)
                seen_idx.add(k)

        for i in range(len(subjects)):
            if i in seen_sub_idx:
                continue
            ret.append(
                {
                    'sub_bbox': {
                        'bbox': subjects[i]['bbox'],
                        'score': subjects[i]['score'],
                    },
                    'obj_bboxes': [],
                    'sub_idx': i,
                }
            )

        return ret
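v3 drops the per-direction voting of v2 in favor of a greedy sweep: subject and object boxes are pooled, the box nearest the current top-left frontier is taken first and paired with the nearest box of the opposite kind, and a guard (pair_dis >= 3 * nearest_dis) rejects the pairing when some other subject is much closer to the object. Unmatched objects are then attached to their nearest subject, and unmatched subjects end with empty obj_bboxes lists. Here is a toy version of the core pairing loop, simplified to fixed anchor points and no distance guard, purely for intuition; it is not the original implementation:

# Illustrative sketch only -- a toy version of the greedy pairing idea.
def greedy_pairs(subs, objs):
    # subs/objs: lists of (x0, y0) anchor points
    pool = ([(p, 'sub', i) for i, p in enumerate(subs)]
            + [(p, 'obj', i) for i, p in enumerate(objs)])
    used, pairs = set(), []
    while True:
        live = [b for b in pool if (b[1], b[2]) not in used]
        if not live:
            break
        # take the box closest to the page origin (the real code re-anchors
        # on a moving frontier) ...
        first = min(live, key=lambda b: b[0][0] ** 2 + b[0][1] ** 2)
        # ... and pair it with the nearest box of the opposite kind
        mates = [b for b in live if b[1] != first[1]]
        if not mates:
            break
        mate = min(mates, key=lambda b: (b[0][0] - first[0][0]) ** 2
                                        + (b[0][1] - first[0][1]) ** 2)
        used.add((first[1], first[2]))
        used.add((mate[1], mate[2]))
        pairs.append((first[2], mate[2]) if first[1] == 'sub' else (mate[2], first[2]))
    return pairs

# two tables and two captions, stacked vertically
print(greedy_pairs(subs=[(0, 0), (0, 100)], objs=[(0, 30), (0, 130)]))  # [(0, 0), (1, 1)]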
    def get_imgs_v2(self, page_no: int):
        with_captions = self.__tie_up_category_by_distance_v3(
            page_no, 3, 4, PosRelationEnum.BOTTOM
        )
        with_footnotes = self.__tie_up_category_by_distance_v3(
            page_no, 3, CategoryId.ImageFootnote, PosRelationEnum.ALL
        )
        ret = []
        for v in with_captions:
            record = {
                'image_body': v['sub_bbox'],
                'image_caption_list': v['obj_bboxes'],
            }
            filter_idx = v['sub_idx']
            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
            record['image_footnote_list'] = d['obj_bboxes']
            ret.append(record)
        return ret
    def get_tables_v2(self, page_no: int) -> list:
        with_captions = self.__tie_up_category_by_distance_v3(
            page_no, 5, 6, PosRelationEnum.UP
        )
        with_footnotes = self.__tie_up_category_by_distance_v3(
            page_no, 5, 7, PosRelationEnum.ALL
        )
        ret = []
        for v in with_captions:
            record = {
                'table_body': v['sub_bbox'],
                'table_caption_list': v['obj_bboxes'],
            }
            filter_idx = v['sub_idx']
            d = next(filter(lambda x: x['sub_idx'] == filter_idx, with_footnotes))
            record['table_footnote_list'] = d['obj_bboxes']
            ret.append(record)
        return ret
    def get_imgs(self, page_no: int):
        return self.get_imgs_v2(page_no)

    def get_tables(self, page_no: int) -> list:
        # three sets of coordinates: caption, table body, table footnote
        return self.get_tables_v2(page_no)
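The records these getters return follow the shapes assembled in get_imgs_v2/get_tables_v2 above. A short consumption sketch; magic_model stands for an already-constructed MagicModel instance:

# Illustrative sketch only: iterating the records returned by get_tables().
for table in magic_model.get_tables(page_no=0):
    body = table['table_body']['bbox']                        # [x0, y0, x1, y1]
    captions = [c['bbox'] for c in table['table_caption_list']]
    footnotes = [f['bbox'] for f in table['table_footnote_list']]
    print(body, len(captions), len(footnotes))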
    def get_equations(self, page_no: int) -> list:
        # has both coordinates and text
        inline_equations = self.__get_blocks_by_type(
            ModelBlockTypeEnum.EMBEDDING.value, page_no, ['latex']
        )
        interline_equations = self.__get_blocks_by_type(
            ModelBlockTypeEnum.ISOLATED.value, page_no, ['latex']
        )
        interline_equations_blocks = self.__get_blocks_by_type(
            ModelBlockTypeEnum.ISOLATE_FORMULA.value, page_no
        )
        return inline_equations, interline_equations, interline_equations_blocks
    def get_discarded(self, page_no: int) -> list:
        # in-house model: coordinates only
        blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.ABANDON.value, page_no)
        return blocks
    def get_text_blocks(self, page_no: int) -> list:
        # produced by the in-house model: coordinates only, no text
        blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.PLAIN_TEXT.value, page_no)
        return blocks
    def get_title_blocks(self, page_no: int) -> list:
        # in-house model: coordinates only, no text
        blocks = self.__get_blocks_by_type(ModelBlockTypeEnum.TITLE.value, page_no)
        return blocks
    def get_ocr_text(self, page_no: int) -> list:
        # produced by paddle: has both text and coordinates
        text_spans = []
        model_page_info = self.__model_list[page_no]
        layout_dets = model_page_info['layout_dets']
        for layout_det in layout_dets:
            if layout_det['category_id'] == '15':
                span = {
                    'bbox': layout_det['bbox'],
                    'content': layout_det['text'],
                }
                text_spans.append(span)
        return text_spans
    def get_all_spans(self, page_no: int) -> list:

        def remove_duplicate_spans(spans):
            new_spans = []
            for span in spans:
                if not any(span == existing_span for existing_span in new_spans):
                    new_spans.append(span)
            return new_spans

        all_spans = []
        model_page_info = self.__model_list[page_no]
        layout_dets = model_page_info['layout_dets']
        allow_category_id_list = [3, 5, 13, 14, 15]
        """These categories are stitched together as spans."""
        # 3: 'image',               # image
        # 5: 'table',               # table
        # 13: 'inline_equation',    # inline formula
        # 14: 'interline_equation', # display (interline) formula
        # 15: 'text',               # OCR-recognized text
        for layout_det in layout_dets:
            category_id = layout_det['category_id']
            if category_id in allow_category_id_list:
                span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
                if category_id == 3:
                    span['type'] = ContentType.Image
                elif category_id == 5:
                    # fetch the table model result
                    latex = layout_det.get('latex', None)
                    html = layout_det.get('html', None)
                    if latex:
                        span['latex'] = latex
                    elif html:
                        span['html'] = html
                    span['type'] = ContentType.Table
                elif category_id == 13:
                    span['content'] = layout_det['latex']
                    span['type'] = ContentType.InlineEquation
                elif category_id == 14:
                    span['content'] = layout_det['latex']
                    span['type'] = ContentType.InterlineEquation
                elif category_id == 15:
                    span['content'] = layout_det['text']
                    span['type'] = ContentType.Text
                all_spans.append(span)
        return remove_duplicate_spans(all_spans)
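Since each span carries only the extra keys its category produces, downstream code has to dispatch on span['type']. An illustrative consumer, again assuming an already-constructed magic_model:

# Illustrative sketch only: consuming spans from get_all_spans().
for span in magic_model.get_all_spans(page_no=0):
    if span['type'] == ContentType.Table:
        body = span.get('html') or span.get('latex')  # present only if recognition succeeded
    elif span['type'] in (ContentType.InlineEquation, ContentType.InterlineEquation):
        body = span['content']                        # LaTeX source
    elif span['type'] == ContentType.Text:
        body = span['content']                        # OCR text
    else:
        body = None                                   # images carry only bbox/score
    print(span['bbox'], span['score'], body)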
    def get_page_size(self, page_no: int):
        # page width and height
        # get the page object for the current page
        page = self.__docs.get_page(page_no).get_page_info()
        # read the current page's width and height
        page_w = page.w
        page_h = page.h
        return page_w, page_h
    def __get_blocks_by_type(
        self, type: int, page_no: int, extra_col: list[str] = []
    ) -> list:
        blocks = []
        for page_dict in self.__model_list:
            layout_dets = page_dict.get('layout_dets', [])
            page_info = page_dict.get('page_info', {})
            page_number = page_info.get('page_no', -1)
            if page_no != page_number:
                continue
            for item in layout_dets:
                category_id = item.get('category_id', -1)
                bbox = item.get('bbox', None)
                if category_id == type:
                    block = {
                        'bbox': bbox,
                        'score': item.get('score'),
                    }
                    for col in extra_col:
                        block[col] = item.get(col, None)
                    blocks.append(block)
        return blocks
    def get_model_list(self, page_no):
        return self.__model_list[page_no]
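A usage sketch for the class as a whole. The constructor itself is outside this diff; the sketch assumes it takes the per-page inference list and a Dataset, as the private fields __model_list and __docs suggest:

# Illustrative sketch only; the construction below is an assumption,
# not shown in this diff.
magic_model = MagicModel(model_list, dataset)
w, h = magic_model.get_page_size(page_no=0)
images = magic_model.get_imgs(page_no=0)
inline, interline, interline_blocks = magic_model.get_equations(page_no=0)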
magic_pdf/model/model_list.py
deleted
100644 → 0
View file @
f5016508
class MODEL:
    Paddle = "pp_structure_v2"
    PEK = "pdf_extract_kit"


class AtomicModel:
    Layout = "layout"
    MFD = "mfd"
    MFR = "mfr"
    OCR = "ocr"
    Table = "table"
    LangDetect = "langdetect"
magic_pdf/model/pdf_extract_kit.py
deleted
100644 → 0
View file @
f5016508
# flake8: noqa
import os
import time

import cv2
import torch
import yaml
from loguru import logger

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # keep albumentations from checking for updates

from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)


class CustomPEKModel:

    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """
        ======== model init ========
        """
        # absolute path of the current file (pdf_extract_kit.py)
        current_file_path = os.path.abspath(__file__)
        # directory containing this file (model)
        current_dir = os.path.dirname(current_file_path)
        # parent directory (magic_pdf)
        root_dir = os.path.dirname(current_dir)
        # the model_config directory
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        # full path to the model_configs.yaml file
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, 'r', encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # initialize parsing configuration

        # layout config
        self.layout_config = kwargs.get('layout_config')
        self.layout_model_name = self.layout_config.get(
            'model', MODEL_NAME.DocLayout_YOLO
        )

        # formula config
        self.formula_config = kwargs.get('formula_config')
        self.mfd_model_name = self.formula_config.get(
            'mfd_model', MODEL_NAME.YOLO_V8_MFD
        )
        self.mfr_model_name = self.formula_config.get(
            'mfr_model', MODEL_NAME.UniMerNet_v2_Small
        )
        self.apply_formula = self.formula_config.get('enable', True)

        # table config
        self.table_config = kwargs.get('table_config')
        self.apply_table = self.table_config.get('enable', False)
        self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
        self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
        self.table_sub_model_name = self.table_config.get('sub_model', None)

        # ocr config
        self.apply_ocr = ocr
        self.lang = kwargs.get('lang', None)

        logger.info(
            'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
            'apply_table: {}, table_model: {}, lang: {}'.format(
                self.layout_model_name,
                self.apply_formula,
                self.apply_ocr,
                self.apply_table,
                self.table_model_name,
                self.lang,
            )
        )
        # initialize the parsing pipeline
        self.device = kwargs.get('device', 'cpu')

        logger.info('using device: {}'.format(self.device))
        models_dir = kwargs.get(
            'models_dir', os.path.join(root_dir, 'resources', 'models')
        )
        logger.info('using models_dir: {}'.format(models_dir))

        atom_model_manager = AtomModelSingleton()

        # initialize formula recognition
        if self.apply_formula:
            # initialize the formula detection model
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.mfd_model_name]
                    )
                ),
                device=self.device,
            )

            # initialize the formula recognition model
            mfr_weight_dir = str(
                os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
            )
            mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))

            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
                device=self.device,
            )

        # initialize the layout model
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
                layout_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                layout_config_file=str(
                    os.path.join(
                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                    )
                ),
                device='cpu' if str(self.device).startswith("mps") else self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
                doclayout_yolo_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                device=self.device,
            )
        # initialize ocr
        self.ocr_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            ocr_show_log=show_log,
            det_db_box_thresh=0.3,
            lang=self.lang
        )
        # init table model
        if self.apply_table:
            table_model_dir = self.configs['weights'][self.table_model_name]
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=str(os.path.join(models_dir, table_model_dir)),
                table_max_time=self.table_max_time,
                device=self.device,
                ocr_engine=self.ocr_model,
                table_sub_model_name=self.table_sub_model_name
            )

        logger.info('DocAnalysis init done!')
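Note that layout_config, formula_config, and table_config are fetched with kwargs.get(...) and then dereferenced unconditionally, so all three dicts must be supplied by the caller. A sketch of the kwargs shape the constructor implies; the concrete values here are hypothetical:

# Illustrative sketch only; values are hypothetical.
model = CustomPEKModel(
    ocr=True,
    show_log=False,
    layout_config={'model': MODEL_NAME.DocLayout_YOLO},
    formula_config={'enable': True},                 # mfd_model/mfr_model fall back to defaults
    table_config={'enable': True, 'max_time': 60},   # hypothetical max_time
    device='cuda',
    lang='en',
)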
    def __call__(self, image):
        # layout detection
        layout_start = time.time()
        layout_res = []
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            layout_res = self.layout_model.predict(image)

        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')

        if self.apply_formula:
            # formula detection
            mfd_start = time.time()
            mfd_res = self.mfd_model.predict(image)
            logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')

            # formula recognition
            mfr_start = time.time()
            formula_list = self.mfr_model.predict(mfd_res, image)
            layout_res.extend(formula_list)
            mfr_cost = round(time.time() - mfr_start, 2)
            logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

        # free up VRAM
        clean_vram(self.device, vram_threshold=6)

        # pull the OCR regions, table regions, and formula regions out of layout_res
        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
            get_res_list_from_layout_res(layout_res)
        )

        # ocr recognition
        ocr_start = time.time()
        # Process each area that requires OCR processing
        for res in ocr_res_list:
            new_image, useful_list = crop_img(
                res, image, crop_paste_x=50, crop_paste_y=50
            )
            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                single_page_mfdetrec_res, useful_list
            )

            # OCR recognition
            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

            if self.apply_ocr:
                ocr_res = self.ocr_model.ocr(
                    new_image, mfd_res=adjusted_mfdetrec_res
                )[0]
            else:
                ocr_res = self.ocr_model.ocr(
                    new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                )[0]

            # Integration results
            if ocr_res:
                ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
                layout_res.extend(ocr_result_list)

        ocr_cost = round(time.time() - ocr_start, 2)
        if self.apply_ocr:
            logger.info(f"ocr time: {ocr_cost}")
        else:
            logger.info(f"det time: {ocr_cost}")

        # table recognition
        if self.apply_table:
            table_start = time.time()
            for res in table_res_list:
                new_image, _ = crop_img(res, image)
                single_table_start_time = time.time()
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
                        table_result = self.table_model.predict(new_image, 'html')
                        if len(table_result) > 0:
                            html_code = table_result[0]
                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                    html_code = self.table_model.img2html(new_image)
                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
                    html_code, table_cell_bboxes, logic_points, elapse = (
                        self.table_model.predict(new_image)
                    )
                run_time = time.time() - single_table_start_time
                if run_time > self.table_max_time:
                    logger.warning(
                        f'table recognition processing exceeds max time {self.table_max_time}s'
                    )
                # check whether the model returned a usable result
                if html_code:
                    expected_ending = html_code.strip().endswith(
                        '</html>'
                    ) or html_code.strip().endswith('</table>')
                    if expected_ending:
                        res['html'] = html_code
                    else:
                        logger.warning(
                            'table recognition processing fails, not found expected HTML table end'
                        )
                else:
                    logger.warning(
                        'table recognition processing fails, not get html return'
                    )
            logger.info(f'table time: {round(time.time() - table_start, 2)}')

        return layout_res
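__call__ processes a single rasterized page and returns the accumulated layout_res list, with formula results, OCR spans, and table HTML merged in. The cv2.COLOR_RGB2BGR conversion above implies the input is an RGB array. A minimal driver sketch:

# Illustrative sketch only: running the pipeline on one rendered page.
# `model` is the CustomPEKModel instance constructed earlier; the blank
# array stands in for a page already rasterized to RGB.
import numpy as np

page_image = np.zeros((1024, 768, 3), dtype=np.uint8)
layout_res = model(page_image)
for det in layout_res:
    print(det.get('category_id'), det.get('bbox'))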