Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
a01bd7ed
Unverified
Commit
a01bd7ed
authored
Mar 03, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Mar 03, 2025
Browse files
Merge pull request #1821 from myhloli/dev
perf(mfr): improve Math Formula Recognition by sorting images by area
parents
058c349c
58b6ad8c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
136 additions
and
19 deletions
+136
-19
magic_pdf/libs/performance_stats.py
magic_pdf/libs/performance_stats.py
+54
-0
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+1
-7
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
+74
-9
magic_pdf/pdf_parse_union_core_v2.py
magic_pdf/pdf_parse_union_core_v2.py
+7
-3
No files found.
magic_pdf/libs/performance_stats.py
0 → 100644
View file @
a01bd7ed
import
time
import
functools
from
collections
import
defaultdict
from
typing
import
Dict
,
List
class PerformanceStats:
    """Shared registry of execution times, keyed by function name.

    All data lives in a class-level mapping, so every caller reads and
    writes the same statistics. The class is used through its classmethods.
    """

    # Maps a function name to the list of durations recorded for it.
    _stats: Dict[str, List[float]] = defaultdict(list)

    @classmethod
    def add_execution_time(cls, func_name: str, execution_time: float):
        """Record one execution time for ``func_name``.

        Args:
            func_name: Key under which the sample is stored.
            execution_time: Duration of a single call (the report headers
                label these values as seconds).
        """
        cls._stats[func_name].append(execution_time)

    @classmethod
    def get_stats(cls) -> Dict[str, dict]:
        """Return aggregate statistics for every function that has samples.

        Returns:
            A dict mapping each function name to a dict with the keys
            ``count``, ``total_time``, ``avg_time``, ``min_time`` and
            ``max_time``. Names without any recorded sample are omitted.
        """
        results = {}
        for func_name, times in cls._stats.items():
            if not times:
                # Reading a missing key from the defaultdict inserts an empty
                # list; skip it so the average, min() and max() cannot fail.
                continue
            total_time = sum(times)
            results[func_name] = {
                'count': len(times),
                'total_time': total_time,
                'avg_time': total_time / len(times),
                'min_time': min(times),
                'max_time': max(times),
            }
        return results

    @classmethod
    def print_stats(cls):
        """Print a table with call count, total time and average time."""
        stats = cls.get_stats()
        print("\n性能统计结果:")
        print("-" * 80)
        # Columns are separated by one space so that a name filling the whole
        # 40-character field does not run into the call count.
        print(f"{'方法名':<40} {'调用次数':>8} {'总时间(s)':>12} {'平均时间(s)':>12}")
        print("-" * 80)
        for func_name, data in stats.items():
            print(
                f"{func_name:<40} {data['count']:8d} "
                f"{data['total_time']:12.6f} {data['avg_time']:12.6f}"
            )
def measure_time(func):
    """Decorator that records how long each call to ``func`` takes.

    The elapsed time of every call that returns normally is passed to
    ``PerformanceStats.add_execution_time`` under ``func.__name__``. A call
    that raises is not recorded; the exception propagates unchanged.

    Args:
        func: The callable to time.

    Returns:
        A wrapper with the same call signature and return value as ``func``.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is monotonic and high resolution, so the measured
        # interval cannot be distorted by system clock adjustments the way
        # a time.time() difference can.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time
        PerformanceStats.add_execution_time(func.__name__, execution_time)
        return result

    return wrapper
\ No newline at end of file
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
a01bd7ed
...
@@ -170,13 +170,7 @@ def doc_analyze(
...
@@ -170,13 +170,7 @@ def doc_analyze(
gpu_memory
=
int
(
os
.
getenv
(
"VIRTUAL_VRAM_SIZE"
,
round
(
get_vram
(
device
))))
gpu_memory
=
int
(
os
.
getenv
(
"VIRTUAL_VRAM_SIZE"
,
round
(
get_vram
(
device
))))
if
gpu_memory
is
not
None
and
gpu_memory
>=
8
:
if
gpu_memory
is
not
None
and
gpu_memory
>=
8
:
if
gpu_memory
>=
40
:
if
gpu_memory
>=
10
:
batch_ratio
=
32
elif
gpu_memory
>=
20
:
batch_ratio
=
16
elif
gpu_memory
>=
16
:
batch_ratio
=
8
elif
gpu_memory
>=
10
:
batch_ratio
=
4
batch_ratio
=
4
else
:
else
:
batch_ratio
=
2
batch_ratio
=
2
...
...
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
View file @
a01bd7ed
...
@@ -100,20 +100,61 @@ class UnimernetModel(object):
...
@@ -100,20 +100,61 @@ class UnimernetModel(object):
res
[
"latex"
]
=
latex_rm_whitespace
(
latex
)
res
[
"latex"
]
=
latex_rm_whitespace
(
latex
)
return
formula_list
return
formula_list
def
batch_predict
(
# def batch_predict(
self
,
images_mfd_res
:
list
,
images
:
list
,
batch_size
:
int
=
64
# self, images_mfd_res: list, images: list, batch_size: int = 64
)
->
list
:
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def
batch_predict
(
self
,
images_mfd_res
:
list
,
images
:
list
,
batch_size
:
int
=
64
)
->
list
:
images_formula_list
=
[]
images_formula_list
=
[]
mf_image_list
=
[]
mf_image_list
=
[]
backfill_list
=
[]
backfill_list
=
[]
image_info
=
[]
# Store (area, original_index, image) tuples
# Collect images with their original indices
for
image_index
in
range
(
len
(
images_mfd_res
)):
for
image_index
in
range
(
len
(
images_mfd_res
)):
mfd_res
=
images_mfd_res
[
image_index
]
mfd_res
=
images_mfd_res
[
image_index
]
pil_img
=
Image
.
fromarray
(
images
[
image_index
])
pil_img
=
Image
.
fromarray
(
images
[
image_index
])
formula_list
=
[]
formula_list
=
[]
for
xyxy
,
conf
,
cla
in
zip
(
for
idx
,
(
xyxy
,
conf
,
cla
)
in
enumerate
(
zip
(
mfd_res
.
boxes
.
xyxy
,
mfd_res
.
boxes
.
conf
,
mfd_res
.
boxes
.
cls
mfd_res
.
boxes
.
xyxy
,
mfd_res
.
boxes
.
conf
,
mfd_res
.
boxes
.
cls
):
)
):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
new_item
=
{
new_item
=
{
"category_id"
:
13
+
int
(
cla
.
item
()),
"category_id"
:
13
+
int
(
cla
.
item
()),
...
@@ -123,19 +164,43 @@ class UnimernetModel(object):
...
@@ -123,19 +164,43 @@ class UnimernetModel(object):
}
}
formula_list
.
append
(
new_item
)
formula_list
.
append
(
new_item
)
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
area
=
(
xmax
-
xmin
)
*
(
ymax
-
ymin
)
curr_idx
=
len
(
mf_image_list
)
image_info
.
append
((
area
,
curr_idx
,
bbox_img
))
mf_image_list
.
append
(
bbox_img
)
mf_image_list
.
append
(
bbox_img
)
images_formula_list
.
append
(
formula_list
)
images_formula_list
.
append
(
formula_list
)
backfill_list
+=
formula_list
backfill_list
+=
formula_list
dataset
=
MathDataset
(
mf_image_list
,
transform
=
self
.
mfr_transform
)
# Stable sort by area
image_info
.
sort
(
key
=
lambda
x
:
x
[
0
])
# sort by area
sorted_indices
=
[
x
[
1
]
for
x
in
image_info
]
sorted_images
=
[
x
[
2
]
for
x
in
image_info
]
# Create mapping for results
index_mapping
=
{
new_idx
:
old_idx
for
new_idx
,
old_idx
in
enumerate
(
sorted_indices
)}
# Create dataset with sorted images
dataset
=
MathDataset
(
sorted_images
,
transform
=
self
.
mfr_transform
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
# Process batches and store results
mfr_res
=
[]
mfr_res
=
[]
for
mf_img
in
dataloader
:
for
mf_img
in
dataloader
:
mf_img
=
mf_img
.
to
(
self
.
device
)
mf_img
=
mf_img
.
to
(
self
.
device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
mfr_res
.
extend
(
output
[
"pred_str"
])
mfr_res
.
extend
(
output
[
"pred_str"
])
for
res
,
latex
in
zip
(
backfill_list
,
mfr_res
):
res
[
"latex"
]
=
latex_rm_whitespace
(
latex
)
# Restore original order
unsorted_results
=
[
""
]
*
len
(
mfr_res
)
for
new_idx
,
latex
in
enumerate
(
mfr_res
):
original_idx
=
index_mapping
[
new_idx
]
unsorted_results
[
original_idx
]
=
latex_rm_whitespace
(
latex
)
# Fill results back
for
res
,
latex
in
zip
(
backfill_list
,
unsorted_results
):
res
[
"latex"
]
=
latex
return
images_formula_list
return
images_formula_list
magic_pdf/pdf_parse_union_core_v2.py
View file @
a01bd7ed
...
@@ -21,9 +21,12 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
...
@@ -21,9 +21,12 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
from
magic_pdf.libs.convert_utils
import
dict_to_list
from
magic_pdf.libs.convert_utils
import
dict_to_list
from
magic_pdf.libs.hash_utils
import
compute_md5
from
magic_pdf.libs.hash_utils
import
compute_md5
from
magic_pdf.libs.pdf_image_tools
import
cut_image_to_pil_image
from
magic_pdf.libs.pdf_image_tools
import
cut_image_to_pil_image
from
magic_pdf.libs.performance_stats
import
measure_time
,
PerformanceStats
from
magic_pdf.model.magic_model
import
MagicModel
from
magic_pdf.model.magic_model
import
MagicModel
from
magic_pdf.post_proc.llm_aided
import
llm_aided_formula
,
llm_aided_text
,
llm_aided_title
from
magic_pdf.post_proc.llm_aided
import
llm_aided_formula
,
llm_aided_text
,
llm_aided_title
from
concurrent.futures
import
ThreadPoolExecutor
try
:
try
:
import
torchtext
import
torchtext
...
@@ -215,7 +218,7 @@ def calculate_contrast(img, img_mode) -> float:
...
@@ -215,7 +218,7 @@ def calculate_contrast(img, img_mode) -> float:
# logger.info(f"contrast: {contrast}")
# logger.info(f"contrast: {contrast}")
return
round
(
contrast
,
2
)
return
round
(
contrast
,
2
)
# @measure_time
def
txt_spans_extract_v2
(
pdf_page
,
spans
,
all_bboxes
,
all_discarded_blocks
,
lang
):
def
txt_spans_extract_v2
(
pdf_page
,
spans
,
all_bboxes
,
all_discarded_blocks
,
lang
):
# cid用0xfffd表示,连字符拆开
# cid用0xfffd表示,连字符拆开
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
...
@@ -489,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
...
@@ -489,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else
:
else
:
return
[[
x0
,
y0
,
x1
,
y1
]]
return
[[
x0
,
y0
,
x1
,
y1
]]
# @measure_time
def
sort_lines_by_model
(
fix_blocks
,
page_w
,
page_h
,
line_height
):
def
sort_lines_by_model
(
fix_blocks
,
page_w
,
page_h
,
line_height
):
page_line_list
=
[]
page_line_list
=
[]
...
@@ -923,7 +926,6 @@ def pdf_parse_union(
...
@@ -923,7 +926,6 @@ def pdf_parse_union(
magic_model
=
MagicModel
(
model_list
,
dataset
)
magic_model
=
MagicModel
(
model_list
,
dataset
)
"""根据输入的起始范围解析pdf"""
"""根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id
=
(
end_page_id
=
(
end_page_id
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
if
end_page_id
is
not
None
and
end_page_id
>=
0
...
@@ -960,6 +962,8 @@ def pdf_parse_union(
...
@@ -960,6 +962,8 @@ def pdf_parse_union(
)
)
pdf_info_dict
[
f
'page_
{
page_id
}
'
]
=
page_info
pdf_info_dict
[
f
'page_
{
page_id
}
'
]
=
page_info
# PerformanceStats.print_stats()
"""分段"""
"""分段"""
para_split
(
pdf_info_dict
)
para_split
(
pdf_info_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment