wangsen / MinerU · Commits

Commit 852ae370 (Unverified)
Authored Jan 15, 2025 by Xiaomeng Zhao; committed by GitHub, Jan 15, 2025

Merge pull request #1550 from myhloli/dev

feat(model): improve batch analysis logic and support npu

Parents: f405cc22, f3502226
Showing 3 changed files with 154 additions and 104 deletions (+154 −104):

magic_pdf/model/batch_analyze.py                  +87  −85
magic_pdf/model/doc_analyze_by_custom_model.py    +66  −18
magic_pdf/model/pdf_extract_kit.py                 +1   −1
magic_pdf/model/batch_analyze.py

@@ -7,17 +7,17 @@ from loguru import logger
 from PIL import Image

 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
+# from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult

 YOLO_LAYOUT_BASE_BATCH_SIZE = 4
 MFD_BASE_BATCH_SIZE = 1
@@ -91,10 +91,12 @@ class BatchAnalyze:
                 images,
                 batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
             )
+            mfr_count = 0
             for image_index in range(len(images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
+                mfr_count += len(images_formula_list[image_index])
             logger.info(
-                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
             )

         # 清理显存 (clean up VRAM)
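The behavioral change in this hunk is that the MFR timing log now reports the number of recognized formulas (mfr_count) rather than the page count (len(images)). Below is a minimal, self-contained Python sketch of that bookkeeping; the variable names mirror the diff, while the toy formula lists are illustrative only:

import time

# Toy stand-ins for the per-page formula results produced by the MFR stage
# (in the real code these come from the formula-recognition model).
images_formula_list = [
    [{'latex': 'a^2 + b^2 = c^2'}, {'latex': 'E = mc^2'}],  # page 0: two formulas
    [],                                                      # page 1: no formulas
    [{'latex': 'x_1'}, {'latex': 'x_2'}],                    # page 2: two formulas
]
images_layout_res = [[] for _ in images_formula_list]

mfr_start_time = time.time()
mfr_count = 0
for image_index in range(len(images_formula_list)):
    # Merge formula detections into the per-page layout results, as in the diff,
    # while accumulating the total number of recognized formulas.
    images_layout_res[image_index] += images_formula_list[image_index]
    mfr_count += len(images_formula_list[image_index])

# Old log: image num = len(images) = 3 pages; new log: image num = mfr_count = 4 formulas.
print(f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}')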
@@ -195,81 +197,81 @@ class BatchAnalyze:
         return images_layout_res

The standalone doc_batch_analyze helper below is commented out in its entirety by this commit (every line gains a leading '# '); its batch path now lives in doc_analyze in doc_analyze_by_custom_model.py, shown next. The function as it stood before the change:

-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -3,8 +3,12 @@ import time
 # 关闭paddle的信号处理 (turn off paddle's signal handling)
 import paddle
+import torch
 from loguru import logger

+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram

 paddle.disable_signal_handler()

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新 (keep albumentations from checking for updates)
@@ -154,33 +158,77 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )

+    batch_analyze = False
+    device = get_device()
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = get_vram(device)
+        if gpu_memory is not None and gpu_memory >= 7:
+            batch_ratio = int((gpu_memory - 3) // 1.5)
+            if batch_ratio >= 1:
+                logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+                batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+                batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()

-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            page_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+    if end_page_id is None:
+        end_page_id = len(dataset)
+
+    if batch_analyze:
+        # batch analyze
+        images = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
+
+    else:
+        # single analyze
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)

     gc_start = time.time()
     clean_memory(get_device())
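The core of the new logic is the VRAM-driven batch_ratio heuristic: batching is attempted only when CUDA (or, via torch_npu, an NPU) is available and get_vram reports at least 7 GB, and the ratio then grows with memory as int((gpu_memory - 3) // 1.5). A small sketch of what this yields for a few card sizes; the helper name and the sample VRAM values are illustrative, only the 7 GB gate and the formula come from the diff:

def pick_batch_ratio(gpu_memory_gb):
    """Mirror of the heuristic added in this commit (hypothetical helper name):
    batching is enabled only with >= 7 GB of reported VRAM, and the ratio
    scales roughly linearly with memory beyond a ~3 GB baseline."""
    if gpu_memory_gb is None or gpu_memory_gb < 7:
        return None  # doc_analyze falls back to single-page analysis
    batch_ratio = int((gpu_memory_gb - 3) // 1.5)
    return batch_ratio if batch_ratio >= 1 else None

for vram in (6, 8, 12, 16, 24, 48):  # sample card sizes, not from the diff
    print(f'{vram} GB -> batch_ratio {pick_batch_ratio(vram)}')
# 6 GB -> None, 8 GB -> 3, 12 GB -> 6, 16 GB -> 8, 24 GB -> 14, 48 GB -> 30

Per the batch_analyze.py hunk above, the effective batch size passed to a sub-model is this ratio multiplied by that model's base batch size (e.g. self.batch_ratio * MFR_BASE_BATCH_SIZE).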
magic_pdf/model/pdf_extract_kit.py

@@ -228,7 +228,7 @@ class CustomPEKModel:
             logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

         # 清理显存 (clean up VRAM)
-        clean_vram(self.device, vram_threshold=8)
+        clean_vram(self.device, vram_threshold=6)

         # 从layout_res中获取ocr区域、表格区域、公式区域 (get the OCR, table, and formula regions from layout_res)
         ocr_res_list, table_res_list, single_page_mfdetrec_res = (
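The only change in this file lowers the vram_threshold argument passed to clean_vram from 8 to 6 GB. clean_vram's implementation is not part of this diff, so the following is only an assumed illustration of a threshold-gated cleanup, not the library's actual code:

import gc
import torch

def clean_vram_sketch(device, vram_threshold=6):
    """Hypothetical illustration: flush caches only on devices whose total
    memory is at or below `vram_threshold` GB (an assumption; the real
    magic_pdf clean_vram may gate on the threshold differently)."""
    if str(device).startswith('cuda') and torch.cuda.is_available():
        total_gb = torch.cuda.get_device_properties(device).total_memory / 1024 ** 3
        if total_gb <= vram_threshold:
            gc.collect()
            torch.cuda.empty_cache()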