MinerU commit a5e41791, authored Jul 22, 2025 by myhloli

fix: adjust batch sizes and improve performance settings in various modules

parent f7f35189
Showing 5 changed files with 17 additions and 10 deletions
mineru/backend/pipeline/batch_analyze.py (+4, -3)
mineru/backend/pipeline/pipeline_analyze.py (+3, -3)
mineru/model/mfr/unimernet/Unimernet.py (+1, -1)
mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py (+8, -2)
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py (+1, -1)
mineru/backend/pipeline/batch_analyze.py

@@ -12,6 +12,7 @@ from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, O
 YOLO_LAYOUT_BASE_BATCH_SIZE = 8
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
+OCR_DET_BASE_BATCH_SIZE = 16
 
 
 class BatchAnalyze:
@@ -170,9 +171,9 @@ class BatchAnalyze:
                 batch_images.append(padded_img)
 
             # Batch detection
-            batch_size = min(len(batch_images), self.batch_ratio * 16)  # increase the batch size
-            # logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
-            batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
+            det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)  # increase the batch size
+            # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}")
+            batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
 
             # Process the batch results
             for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
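
For reference, a minimal standalone sketch of the batch-size computation introduced above: the OCR detection batch now scales a named base constant by batch_ratio and is capped by the number of padded crops. The helper name and the sample numbers are illustrative, not part of the commit.

OCR_DET_BASE_BATCH_SIZE = 16

def compute_det_batch_size(num_images, batch_ratio):
    # Never request a detector batch larger than the number of crops available.
    return min(num_images, batch_ratio * OCR_DET_BASE_BATCH_SIZE)

print(compute_det_batch_size(40, 2))  # 32, capped by batch_ratio * OCR_DET_BASE_BATCH_SIZE
print(compute_det_batch_size(10, 2))  # 10, capped by the number of images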
mineru/backend/pipeline/pipeline_analyze.py

@@ -74,10 +74,10 @@ def doc_analyze(
         table_enable=True,
 ):
     """
-    Moderately increasing MIN_BATCH_INFERENCE_SIZE can improve performance, but may increase VRAM usage.
-    It can be set via the environment variable MINERU_MIN_BATCH_INFERENCE_SIZE; the default is 128.
+    Moderately increasing MIN_BATCH_INFERENCE_SIZE can improve performance; a larger MIN_BATCH_INFERENCE_SIZE consumes more memory.
+    It can be set via the environment variable MINERU_MIN_BATCH_INFERENCE_SIZE; the default is 384.
     """
-    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 128))
+    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 384))
 
     # Collect information for all pages
     all_pages_info = []  # stores (dataset_index, page_index, img, ocr, lang, width, height)
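
A small sketch of the configuration pattern touched here, assuming only the standard library: the tunable is read from the MINERU_MIN_BATCH_INFERENCE_SIZE environment variable and falls back to the new default of 384. The wrapper function is illustrative.

import os

def get_min_batch_inference_size(default=384):
    # Environment variable wins; otherwise use the new default of 384.
    return int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', default))

print(get_min_batch_inference_size())  # 384 when the variable is unset
os.environ['MINERU_MIN_BATCH_INFERENCE_SIZE'] = '512'
print(get_min_batch_inference_size())  # 512 after an explicit override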
mineru/model/mfr/unimernet/Unimernet.py

@@ -115,7 +115,7 @@ class UnimernetModel(object):
             mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
+                output = self.model.generate({"image": mf_img}, batch_size=batch_size)
             mfr_res.extend(output["fixed_str"])
             # Advance the progress bar by batch_size each step; note the last batch may contain fewer than batch_size items
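
A hedged sketch of the calling pattern changed above, with a pure-Python stub standing in for the actual UnimernetModel: the wrapper now forwards its batch_size keyword into model.generate so the inner model can adapt to it. Only the keyword forwarding mirrors the commit; the stub class and run_mfr helper are invented for illustration.

class _StubMfrModel:
    def generate(self, samples, batch_size=64):
        # Pretend to decode: one fixed string per input "image".
        return {"fixed_str": [f"formula@bs{batch_size}" for _ in samples["image"]]}

def run_mfr(model, image_batches, batch_size=32):
    mfr_res = []
    for mf_img in image_batches:
        # Forward the caller's batch_size, as the diff above now does.
        output = model.generate({"image": mf_img}, batch_size=batch_size)
        mfr_res.extend(output["fixed_str"])
    return mfr_res

print(run_mfr(_StubMfrModel(), [["img1", "img2"], ["img3"]], batch_size=16))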
mineru/model/mfr/unimernet/unimernet_hf/modeling_unimernet.py

@@ -468,7 +468,7 @@ class UnimernetModel(VisionEncoderDecoderModel):
         ).loss
         return {"loss": loss}
 
-    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
+    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95, batch_size=64):
         pixel_values = samples["image"]
         num_channels = pixel_values.shape[1]
         if num_channels == 1:
@@ -478,7 +478,13 @@ class UnimernetModel(VisionEncoderDecoderModel):
         if do_sample:
             kwargs["temperature"] = temperature
             kwargs["top_p"] = top_p
 
-        if self.tokenizer.tokenizer.model_max_length > 1152:
+        if batch_size <= 32:
             self.tokenizer.tokenizer.model_max_length = 1152  # 6g
+        else:
+            self.tokenizer.tokenizer.model_max_length = 1344  # 8g
 
         outputs = super().generate(
             pixel_values=pixel_values,
             max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
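
The new cap selection can be restated as a tiny pure function; this sketch only mirrors the branch introduced above, and the function name is illustrative.

def select_model_max_length(batch_size):
    # Batches of 32 or fewer keep the 1152-token cap (the "6g" note);
    # larger batches use 1344 (the "8g" note). The old code always capped at 1152.
    return 1152 if batch_size <= 32 else 1344

assert select_model_max_length(32) == 1152
assert select_model_max_length(64) == 1344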
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -88,7 +88,7 @@ class PytorchPaddleOCR(TextSystem):
         kwargs['det_model_path'] = det_model_path
         kwargs['rec_model_path'] = rec_model_path
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
-        # kwargs['rec_batch_num'] = 8
+        kwargs['rec_batch_num'] = 16
         kwargs['device'] = device
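
A sketch of the kwargs wiring shown above, using a hypothetical helper and an example dictionary file name rather than the real PytorchPaddleOCR constructor: the recognition batch size is now set explicitly to 16.

import os

def build_ocr_kwargs(root_dir, dict_file, device, **kwargs):
    kwargs['rec_char_dict_path'] = os.path.join(
        root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
    kwargs['rec_batch_num'] = 16  # recognition batch size set to 16 in this commit
    kwargs['device'] = device
    return kwargs

print(build_ocr_kwargs('/path/to/models', 'ppocr_keys_v1.txt', 'cpu'))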