wangsen / MinerU · Commits

Commit 7ce9edc6, authored Dec 12, 2024 by Suven
feat: add batch prediction methods for YOLOv8 and Unimernet models
parent 8f266869

Showing 5 changed files with 364 additions and 36 deletions

magic_pdf/config/exceptions.py                                          +7   -0
magic_pdf/model/batch_analyze.py                                        +228 -0
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py      +41  -7
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py                        +18  -2
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py                  +70  -27

magic_pdf/config/exceptions.py

@@ -30,3 +30,10 @@ class EmptyData(Exception):
     def __str__(self):
         return f'Empty data: {self.msg}'
 
+
+class CUDA_NOT_AVAILABLE(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'CUDA not available: {self.msg}'
\ No newline at end of file
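
Usage note: this exception is what the new batch entry point raises when no GPU is present. A minimal sketch of guarding a call site (the require_cuda wrapper is illustrative, not part of this commit):

import torch

from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE


def require_cuda():
    # Same check as doc_batch_analyze() below: batch mode needs a GPU.
    if not torch.cuda.is_available():
        raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")


try:
    require_cuda()
except CUDA_NOT_AVAILABLE as e:
    print(e)  # -> CUDA not available: batch analyze not support in CPU mode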

magic_pdf/model/batch_analyze.py  (new file, mode 100644)

import time

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.operators import InferenceResult
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram,
    crop_img,
    get_res_list_from_layout_res,
)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
    get_adjusted_mfdetrec_res,
    get_ocr_result_list,
)

YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16


class BatchAnalyze:
    def __init__(self, model: CustomPEKModel, batch_ratio: int):
        self.model = model
        self.batch_ratio = batch_ratio

    def __call__(self, images: list) -> list:
        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            images_layout_res = []
            for image in images:
                layout_res = self.model.layout_model(image, ignore_catids=[])
                images_layout_res.append(layout_res)
        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            images_layout_res = self.model.layout_model.batch_predict(
                images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
            )

        if self.model.apply_formula:
            # formula detection
            images_mfd_res = self.model.mfd_model.batch_predict(
                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
            )
            # formula recognition
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]

        # clean up VRAM
        clean_vram(self.model.device, vram_threshold=8)

        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
        for index in range(len(images)):
            layout_res = images_layout_res[index]
            pil_img = Image.fromarray(images[index])

            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )

            # OCR recognition
            ocr_start = time.time()
            # Process each area that requires OCR processing
            for res in ocr_res_list:
                new_image, useful_list = crop_img(
                    res, pil_img, crop_paste_x=50, crop_paste_y=50
                )
                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                    single_page_mfdetrec_res, useful_list
                )

                # OCR recognition
                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)

                if self.model.apply_ocr:
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res
                    )[0]
                else:
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                    )[0]

                # Integration results
                if ocr_res:
                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
                    layout_res.extend(ocr_result_list)

            ocr_cost = round(time.time() - ocr_start, 2)
            if self.model.apply_ocr:
                logger.info(f"ocr time: {ocr_cost}")
            else:
                logger.info(f"det time: {ocr_cost}")

            # table recognition
            if self.model.apply_table:
                table_start = time.time()
                for res in table_res_list:
                    new_image, _ = crop_img(res, pil_img)
                    single_table_start_time = time.time()
                    html_code = None
                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                        with torch.no_grad():
                            table_result = self.model.table_model.predict(
                                new_image, "html"
                            )
                            if len(table_result) > 0:
                                html_code = table_result[0]
                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                        html_code = self.model.table_model.img2html(new_image)
                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
                        html_code, table_cell_bboxes, elapse = (
                            self.model.table_model.predict(new_image)
                        )
                    run_time = time.time() - single_table_start_time
                    if run_time > self.model.table_max_time:
                        logger.warning(
                            f"table recognition processing exceeds max time {self.model.table_max_time}s"
                        )
                    # check whether a normal result was returned
                    if html_code:
                        expected_ending = html_code.strip().endswith(
                            "</html>"
                        ) or html_code.strip().endswith("</table>")
                        if expected_ending:
                            res["html"] = html_code
                        else:
                            logger.warning(
                                "table recognition processing fails, not found expected HTML table end"
                            )
                    else:
                        logger.warning(
                            "table recognition processing fails, not get html return"
                        )
                logger.info(f"table time: {round(time.time() - table_start, 2)}")


def doc_batch_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    batch_ratio: int | None = None,
) -> InferenceResult:
    """Perform batch analysis on a document dataset.

    Args:
        dataset (Dataset): The dataset containing document pages to be analyzed.
        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
        show_log (bool, optional): Flag to enable logging. Defaults to False.
        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
        lang (str, optional): Language for OCR. Defaults to None.
        layout_model (optional): Layout model to be used for analysis. Defaults to None.
        formula_enable (optional): Flag to enable formula detection. Defaults to None.
        table_enable (optional): Flag to enable table detection. Defaults to None.
        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.

    Raises:
        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.

    Returns:
        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
    """
    if not torch.cuda.is_available():
        raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")

    lang = None if lang == "" else lang
    # TODO: auto detect batch size
    batch_ratio = 1 if batch_ratio is None else batch_ratio
    end_page_id = end_page_id if end_page_id else len(dataset)

    model_manager = ModelSingleton()
    custom_model: CustomPEKModel = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )
    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

    model_json = []

    # batch analyze
    images = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict["img"])
    analyze_result = batch_model(images)

    for index in range(len(dataset)):
        page_data = dataset.get_page(index)
        img_dict = page_data.get_image()
        page_width = img_dict["width"]
        page_height = img_dict["height"]
        if start_page_id <= index <= end_page_id:
            result = analyze_result.pop(0)
        else:
            result = []

        page_info = {"page_no": index, "height": page_height, "width": page_width}
        page_dict = {"layout_dets": result, "page_info": page_info}
        model_json.append(page_dict)

    # TODO: clean memory when gpu memory is not enough
    clean_memory()

    return InferenceResult(model_json, dataset)
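
Usage note: a minimal sketch of driving the new entry point. The PymuDocDataset import and the input path are assumptions based on the project's demo scripts, not part of this commit, and a CUDA device is required:

from magic_pdf.data.dataset import PymuDocDataset  # assumed Dataset implementation
from magic_pdf.model.batch_analyze import doc_batch_analyze

with open("demo.pdf", "rb") as f:  # placeholder input file
    dataset = PymuDocDataset(f.read())

# batch_ratio scales the per-model base batch sizes (layout 4, MFD 1, MFR 16).
infer_result = doc_batch_analyze(dataset, ocr=True, batch_ratio=2)
# infer_result wraps one {"layout_dets": ..., "page_info": ...} dict per page, plus the dataset.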

magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -8,14 +8,48 @@ class DocLayoutYOLOModel(object):
     def predict(self, image):
         layout_res = []
-        doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
-        for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
+        doclayout_yolo_res = self.model.predict(
+            image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device
+        )[0]
+        for xyxy, conf, cla in zip(
+            doclayout_yolo_res.boxes.xyxy.cpu(),
+            doclayout_yolo_res.boxes.conf.cpu(),
+            doclayout_yolo_res.boxes.cls.cpu(),
+        ):
             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
             new_item = {
-                'category_id': int(cla.item()),
-                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                'score': round(float(conf.item()), 3),
+                "category_id": int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 3),
             }
             layout_res.append(new_item)
-        return layout_res
\ No newline at end of file
+        return layout_res
+
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_layout_res = []
+        for index in range(0, len(images), batch_size):
+            doclayout_yolo_res = self.model.predict(
+                images[index : index + batch_size],
+                imgsz=1024,
+                conf=0.25,
+                iou=0.45,
+                verbose=True,
+                device=self.device,
+            ).cpu()
+            for image_res in doclayout_yolo_res:
+                layout_res = []
+                for xyxy, conf, cla in zip(
+                    image_res.boxes.xyxy,
+                    image_res.boxes.conf,
+                    image_res.boxes.cls,
+                ):
+                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                    new_item = {
+                        "category_id": int(cla.item()),
+                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                        "score": round(float(conf.item()), 3),
+                    }
+                    layout_res.append(new_item)
+                images_layout_res.append(layout_res)
+        return images_layout_res
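
Usage note: a minimal sketch of calling the new layout batch API directly. The weight path is a placeholder and the DocLayoutYOLOModel(weight, device) constructor signature is an assumption (its __init__ is outside this hunk); in the pipeline the batch size comes from batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE:

import cv2

from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel

layout_model = DocLayoutYOLOModel("models/doclayout_yolo_ft.pt", device="cuda")  # assumed signature

pages = [cv2.imread(p) for p in ("page_0.png", "page_1.png")]  # placeholder page images
per_page_layouts = layout_model.batch_predict(pages, batch_size=8)  # e.g. batch_ratio=2 * base 4
# per_page_layouts[i] is the list of {"category_id", "poly", "score"} dicts for pages[i].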

magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -2,11 +2,27 @@ from ultralytics import YOLO
 
 class YOLOv8MFDModel(object):
-    def __init__(self, weight, device='cpu'):
+    def __init__(self, weight, device="cpu"):
         self.mfd_model = YOLO(weight)
         self.device = device
 
     def predict(self, image):
-        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        mfd_res = self.mfd_model.predict(
+            image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device
+        )[0]
         return mfd_res
+
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_mfd_res = []
+        for index in range(0, len(images), batch_size):
+            mfd_res = self.mfd_model.predict(
+                images[index : index + batch_size],
+                imgsz=1888,
+                conf=0.25,
+                iou=0.45,
+                verbose=True,
+                device=self.device,
+            ).cpu()
+            for image_res in mfd_res:
+                images_mfd_res.append(image_res)
+        return images_mfd_res
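
Usage note: the matching formula-detection sketch. The weight path is a placeholder; the (weight, device) constructor is the one shown in this hunk:

import numpy as np

from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel

mfd_model = YOLOv8MFDModel("models/yolo_v8_mfd.pt", device="cuda")  # placeholder weight path

pages = [np.zeros((1024, 768, 3), dtype=np.uint8) for _ in range(3)]  # dummy page images
# MFD_BASE_BATCH_SIZE is 1 in batch_analyze.py, so the effective batch size equals batch_ratio.
mfd_results = mfd_model.batch_predict(pages, batch_size=2)  # one detection result per page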

magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

-import os
 import argparse
+import os
 import re
-from PIL import Image
+
 import torch
-from torch.utils.data import Dataset, DataLoader
+import unimernet.tasks as tasks
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from unimernet.common.config import Config
-import unimernet.tasks as tasks
 from unimernet.processors import load_processor

@@ -31,27 +31,25 @@ class MathDataset(Dataset):
 def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code.
-    """
-    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
-    letter = '[a-zA-Z]'
-    noletter = '[\W_^\d]'
-    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    """Remove unnecessary whitespace from LaTeX code."""
+    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+    letter = "[a-zA-Z]"
+    noletter = "[\W_^\d]"
+    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
     s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
     news = s
     while True:
         s = news
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
-        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
         if news == s:
             break
     return s
 
 
 class UnimernetModel(object):
-    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
+    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
         args = argparse.Namespace(cfg_path=cfg_path, options=None)
         cfg = Config(args)
         cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")

@@ -62,20 +60,28 @@ class UnimernetModel(object):
         self.device = _device_
         self.model.to(_device_)
         self.model.eval()
-        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-        self.mfr_transform = transforms.Compose([vis_processor, ])
+        vis_processor = load_processor(
+            "formula_image_eval",
+            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
+        )
+        self.mfr_transform = transforms.Compose(
+            [
+                vis_processor,
+            ]
+        )
 
     def predict(self, mfd_res, image):
         formula_list = []
         mf_image_list = []
-        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+        for xyxy, conf, cla in zip(
+            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
+        ):
             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
             new_item = {
-                'category_id': 13 + int(cla.item()),
-                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                'score': round(float(conf.item()), 2),
-                'latex': '',
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+                "latex": "",
             }
             formula_list.append(new_item)
 
         pil_img = Image.fromarray(image)

@@ -88,11 +94,48 @@ class UnimernetModel(object):
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
-                output = self.model.generate({'image': mf_img})
-            mfr_res.extend(output['pred_str'])
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res['latex'] = latex_rm_whitespace(latex)
+            res["latex"] = latex_rm_whitespace(latex)
         return formula_list
+
+    def batch_predict(
+        self, images_mfd_res: list, images: list, batch_size: int = 64
+    ) -> list:
+        images_formula_list = []
+        mf_image_list = []
+        backfill_list = []
+        for image_index in range(len(images_mfd_res)):
+            mfd_res = images_mfd_res[image_index]
+            pil_img = Image.fromarray(images[image_index])
+            formula_list = []
+
+            for xyxy, conf, cla in zip(
+                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+            ):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    "category_id": 13 + int(cla.item()),
+                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    "score": round(float(conf.item()), 2),
+                    "latex": "",
+                }
+                formula_list.append(new_item)
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                mf_image_list.append(bbox_img)
+
+            images_formula_list.append(formula_list)
+            backfill_list += formula_list
+
+        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+        mfr_res = []
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            with torch.no_grad():
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
+        for res, latex in zip(backfill_list, mfr_res):
+            res["latex"] = latex_rm_whitespace(latex)
+        return images_formula_list
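
Usage note: chained together, the two new methods reproduce the formula path of BatchAnalyze.__call__. A condensed sketch with placeholder model paths (real values come from MinerU's model config):

import numpy as np

from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel

mfd_model = YOLOv8MFDModel("models/yolo_v8_mfd.pt", device="cuda")  # placeholder path
mfr_model = UnimernetModel("models/unimernet", "configs/unimernet.yaml", _device_="cuda")  # placeholder paths

pages = [np.zeros((1024, 768, 3), dtype=np.uint8) for _ in range(2)]  # dummy page images
batch_ratio = 2

# Detect formulas per page, then recognize all crops in one batched pass.
images_mfd_res = mfd_model.batch_predict(pages, batch_size=batch_ratio * 1)  # MFD_BASE_BATCH_SIZE = 1
images_formula_list = mfr_model.batch_predict(
    images_mfd_res, pages, batch_size=batch_ratio * 16  # MFR_BASE_BATCH_SIZE = 16
)
# images_formula_list[i] holds pages[i]'s formula items with their "latex" fields filled in.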