Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
852ae370
Unverified
Commit
852ae370
authored
Jan 15, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Jan 15, 2025
Browse files
Merge pull request #1550 from myhloli/dev
feat(model): improve batch analysis logic and support npu
parents
f405cc22
f3502226
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
154 additions
and
104 deletions
+154
-104
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+87
-85
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+66
-18
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+1
-1
No files found.
magic_pdf/model/batch_analyze.py
View file @
852ae370
...
...
@@ -7,17 +7,17 @@ from loguru import logger
from
PIL
import
Image
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.config.exceptions
import
CUDA_NOT_AVAILABLE
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.config_reader
import
get_device
from
magic_pdf.model.doc_analyze_by_custom_model
import
ModelSingleton
#
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
#
from magic_pdf.data.dataset import Dataset
#
from magic_pdf.libs.clean_memory import clean_memory
#
from magic_pdf.libs.config_reader import get_device
#
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
from
magic_pdf.operators.models
import
InferenceResult
#
from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE
=
4
MFD_BASE_BATCH_SIZE
=
1
...
...
@@ -91,10 +91,12 @@ class BatchAnalyze:
images
,
batch_size
=
self
.
batch_ratio
*
MFR_BASE_BATCH_SIZE
,
)
mfr_count
=
0
for
image_index
in
range
(
len
(
images
)):
images_layout_res
[
image_index
]
+=
images_formula_list
[
image_index
]
mfr_count
+=
len
(
images_formula_list
[
image_index
])
logger
.
info
(
f
'mfr time:
{
round
(
time
.
time
()
-
mfr_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
f
'mfr time:
{
round
(
time
.
time
()
-
mfr_start_time
,
2
)
}
, image num:
{
mfr_count
}
'
)
# 清理显存
...
...
@@ -195,81 +197,81 @@ class BatchAnalyze:
return
images_layout_res
def
doc_batch_analyze
(
dataset
:
Dataset
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
batch_ratio
:
int
|
None
=
None
,
)
->
InferenceResult
:
"""Perform batch analysis on a document dataset.
Args:
dataset (Dataset): The dataset containing document pages to be analyzed.
ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
show_log (bool, optional): Flag to enable logging. Defaults to False.
start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
lang (str, optional): Language for OCR. Defaults to None.
layout_model (optional): Layout model to be used for analysis. Defaults to None.
formula_enable (optional): Flag to enable formula detection. Defaults to None.
table_enable (optional): Flag to enable table detection. Defaults to None.
batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
Raises:
CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
Returns:
InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
"""
if
not
torch
.
cuda
.
is_available
():
raise
CUDA_NOT_AVAILABLE
(
'batch analyze not support in CPU mode'
)
lang
=
None
if
lang
==
''
else
lang
# TODO: auto detect batch size
batch_ratio
=
1
if
batch_ratio
is
None
else
batch_ratio
end_page_id
=
end_page_id
if
end_page_id
else
len
(
dataset
)
model_manager
=
ModelSingleton
()
custom_model
:
CustomPEKModel
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
model_json
=
[]
# batch analyze
images
=
[]
for
index
in
range
(
len
(
dataset
)):
if
start_page_id
<=
index
<=
end_page_id
:
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
analyze_result
=
batch_model
(
images
)
for
index
in
range
(
len
(
dataset
)):
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
page_width
=
img_dict
[
'width'
]
page_height
=
img_dict
[
'height'
]
if
start_page_id
<=
index
<=
end_page_id
:
result
=
analyze_result
.
pop
(
0
)
else
:
result
=
[]
page_info
=
{
'page_no'
:
index
,
'height'
:
page_height
,
'width'
:
page_width
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time
=
time
.
time
()
clean_memory
(
get_device
())
logger
.
info
(
f
'clean memory time:
{
round
(
time
.
time
()
-
clean_memory_start_time
,
2
)
}
'
)
return
InferenceResult
(
model_json
,
dataset
)
#
def doc_batch_analyze(
#
dataset: Dataset,
#
ocr: bool = False,
#
show_log: bool = False,
#
start_page_id=0,
#
end_page_id=None,
#
lang=None,
#
layout_model=None,
#
formula_enable=None,
#
table_enable=None,
#
batch_ratio: int | None = None,
#
) -> InferenceResult:
#
"""Perform batch analysis on a document dataset.
#
#
Args:
#
dataset (Dataset): The dataset containing document pages to be analyzed.
#
ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
#
show_log (bool, optional): Flag to enable logging. Defaults to False.
#
start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
#
end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
#
lang (str, optional): Language for OCR. Defaults to None.
#
layout_model (optional): Layout model to be used for analysis. Defaults to None.
#
formula_enable (optional): Flag to enable formula detection. Defaults to None.
#
table_enable (optional): Flag to enable table detection. Defaults to None.
#
batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
#
Raises:
#
CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
#
Returns:
#
InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
#
"""
#
#
if not torch.cuda.is_available():
#
raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
#
lang = None if lang == '' else lang
#
# TODO: auto detect batch size
#
batch_ratio = 1 if batch_ratio is None else batch_ratio
#
end_page_id = end_page_id if end_page_id else len(dataset)
#
#
model_manager = ModelSingleton()
#
custom_model: CustomPEKModel = model_manager.get_model(
#
ocr, show_log, lang, layout_model, formula_enable, table_enable
#
)
#
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
#
model_json = []
#
#
# batch analyze
#
images = []
#
for index in range(len(dataset)):
#
if start_page_id <= index <= end_page_id:
#
page_data = dataset.get_page(index)
#
img_dict = page_data.get_image()
#
images.append(img_dict['img'])
#
analyze_result = batch_model(images)
#
#
for index in range(len(dataset)):
#
page_data = dataset.get_page(index)
#
img_dict = page_data.get_image()
#
page_width = img_dict['width']
#
page_height = img_dict['height']
#
if start_page_id <= index <= end_page_id:
#
result = analyze_result.pop(0)
#
else:
#
result = []
#
#
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
#
page_dict = {'layout_dets': result, 'page_info': page_info}
#
model_json.append(page_dict)
#
#
# TODO: clean memory when gpu memory is not enough
#
clean_memory_start_time = time.time()
#
clean_memory(get_device())
#
logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
#
return InferenceResult(model_json, dataset)
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
852ae370
...
...
@@ -3,8 +3,12 @@ import time
# 关闭paddle的信号处理
import
paddle
import
torch
from
loguru
import
logger
from
magic_pdf.model.batch_analyze
import
BatchAnalyze
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
paddle
.
disable_signal_handler
()
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
...
...
@@ -154,33 +158,77 @@ def doc_analyze(
table_enable
=
None
,
)
->
InferenceResult
:
end_page_id
=
end_page_id
if
end_page_id
else
len
(
dataset
)
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
batch_analyze
=
False
device
=
get_device
()
npu_support
=
False
if
str
(
device
).
startswith
(
"npu"
):
import
torch_npu
if
torch_npu
.
npu
.
is_available
():
npu_support
=
True
if
torch
.
cuda
.
is_available
()
and
device
!=
'cpu'
or
npu_support
:
gpu_memory
=
get_vram
(
device
)
if
gpu_memory
is
not
None
and
gpu_memory
>=
7
:
batch_ratio
=
int
((
gpu_memory
-
3
)
//
1.5
)
if
batch_ratio
>=
1
:
logger
.
info
(
f
'gpu_memory:
{
gpu_memory
}
GB, batch_ratio:
{
batch_ratio
}
'
)
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
batch_analyze
=
True
model_json
=
[]
doc_analyze_start
=
time
.
time
()
if
end_page_id
is
None
:
end_page_id
=
len
(
dataset
)
for
index
in
range
(
len
(
dataset
)):
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
img
=
img_dict
[
'img'
]
page_width
=
img_dict
[
'width'
]
page_height
=
img_dict
[
'height'
]
if
start_page_id
<=
index
<=
end_page_id
:
page_start
=
time
.
time
()
result
=
custom_model
(
img
)
logger
.
info
(
f
'-----page_id :
{
index
}
, page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----'
)
else
:
result
=
[]
if
batch_analyze
:
# batch analyze
images
=
[]
for
index
in
range
(
len
(
dataset
)):
if
start_page_id
<=
index
<=
end_page_id
:
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
analyze_result
=
batch_model
(
images
)
for
index
in
range
(
len
(
dataset
)):
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
page_width
=
img_dict
[
'width'
]
page_height
=
img_dict
[
'height'
]
if
start_page_id
<=
index
<=
end_page_id
:
result
=
analyze_result
.
pop
(
0
)
else
:
result
=
[]
page_info
=
{
'page_no'
:
index
,
'height'
:
page_height
,
'width'
:
page_width
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
page_info
=
{
'page_no'
:
index
,
'height'
:
page_height
,
'width'
:
page_width
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
else
:
# single analyze
for
index
in
range
(
len
(
dataset
)):
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
img
=
img_dict
[
'img'
]
page_width
=
img_dict
[
'width'
]
page_height
=
img_dict
[
'height'
]
if
start_page_id
<=
index
<=
end_page_id
:
page_start
=
time
.
time
()
result
=
custom_model
(
img
)
logger
.
info
(
f
'-----page_id :
{
index
}
, page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----'
)
else
:
result
=
[]
page_info
=
{
'page_no'
:
index
,
'height'
:
page_height
,
'width'
:
page_width
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
gc_start
=
time
.
time
()
clean_memory
(
get_device
())
...
...
magic_pdf/model/pdf_extract_kit.py
View file @
852ae370
...
...
@@ -228,7 +228,7 @@ class CustomPEKModel:
logger
.
info
(
f
'formula nums:
{
len
(
formula_list
)
}
, mfr time:
{
mfr_cost
}
'
)
# 清理显存
clean_vram
(
self
.
device
,
vram_threshold
=
8
)
clean_vram
(
self
.
device
,
vram_threshold
=
6
)
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment