Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
99d4c97a
"...schedulers/scheduling_edm_dpmsolver_multistep.py" did not exist on "aaaec06487ffc3cd5d1c1d4e1a64af80dbaf477b"
Commit
99d4c97a
authored
May 28, 2025
by
speta
Browse files
支持batch-ocr-det,速度约提升3倍(200页pdf在3090上)
parent
f5016508
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
272 additions
and
38 deletions
+272
-38
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+150
-38
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py
..._modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py
+122
-0
No files found.
magic_pdf/model/batch_analyze.py
View file @
99d4c97a
...
...
@@ -2,6 +2,8 @@ import time
import
cv2
from
loguru
import
logger
from
tqdm
import
tqdm
from
collections
import
defaultdict
import
numpy
as
np
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
...
...
@@ -16,27 +18,28 @@ MFR_BASE_BATCH_SIZE = 16
class
BatchAnalyze
:
def
__init__
(
self
,
model_manager
,
batch_ratio
:
int
,
show_log
,
layout_model
,
formula_enable
,
table_enable
):
def __init__(
    self,
    model_manager,
    batch_ratio: int,
    show_log,
    layout_model,
    formula_enable,
    table_enable,
    enable_ocr_det_batch=True,
):
    """Remember the settings that ``__call__`` uses for batch analysis.

    Args:
        model_manager: Object whose ``get_model`` is invoked in ``__call__``.
        batch_ratio: Integer multiplier for batch sizes; the OCR detection
            batch size is derived from ``batch_ratio * 16``.
        show_log: Forwarded to ``model_manager.get_model``.
        layout_model: Forwarded to ``model_manager.get_model``.
        formula_enable: Forwarded to ``model_manager.get_model``.
        table_enable: Forwarded to ``model_manager.get_model``.
        enable_ocr_det_batch: When true, ``__call__`` takes the batched OCR
            detection path; otherwise it detects crop by crop.
    """
    # Switch between batched and per-crop OCR detection.
    self.enable_ocr_det_batch = enable_ocr_det_batch

    # Batch sizing.
    self.batch_ratio = batch_ratio

    # Everything below is handed to the model manager when models are loaded.
    self.model_manager = model_manager
    self.show_log = show_log
    self.layout_model = layout_model
    self.formula_enable = formula_enable
    self.table_enable = table_enable
def
__call__
(
self
,
images_with_extra_info
:
list
)
->
list
:
if
len
(
images_with_extra_info
)
==
0
:
return
[]
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
self
.
model
=
self
.
model_manager
.
get_model
(
ocr
=
True
,
show_log
=
self
.
show_log
,
lang
=
None
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
,
lang
=
None
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
,
)
images
=
[
image
for
image
,
_
,
_
in
images_with_extra_info
]
...
...
@@ -101,43 +104,152 @@ class BatchAnalyze:
get_res_list_from_layout_res
(
layout_res
)
)
ocr_res_list_all_page
.
append
({
'ocr_res_list'
:
ocr_res_list
,
'lang'
:
_lang
,
'ocr_enable'
:
ocr_enable
,
'np_array_img'
:
np_array_img
,
'single_page_mfdetrec_res'
:
single_page_mfdetrec_res
,
'layout_res'
:
layout_res
,
})
ocr_res_list_all_page
.
append
({
'ocr_res_list'
:
ocr_res_list
,
'lang'
:
_lang
,
'ocr_enable'
:
ocr_enable
,
'np_array_img'
:
np_array_img
,
'single_page_mfdetrec_res'
:
single_page_mfdetrec_res
,
'layout_res'
:
layout_res
,
})
for
table_res
in
table_res_list
:
table_img
,
_
=
crop_img
(
table_res
,
np_array_img
)
table_res_list_all_page
.
append
({
'table_res'
:
table_res
,
'lang'
:
_lang
,
'table_img'
:
table_img
,
})
# 文本框检测
det_start
=
time
.
time
()
det_count
=
0
# for ocr_res_list_dict in ocr_res_list_all_page:
for
ocr_res_list_dict
in
tqdm
(
ocr_res_list_all_page
,
desc
=
"OCR-det Predict"
):
# Process each area that requires OCR processing
_lang
=
ocr_res_list_dict
[
'lang'
]
# Get OCR results for this language's images
atom_model_manager
=
AtomModelSingleton
()
ocr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
'ocr'
,
ocr_show_log
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
_lang
)
for
res
in
ocr_res_list_dict
[
'ocr_res_list'
]:
new_image
,
useful_list
=
crop_img
(
res
,
ocr_res_list_dict
[
'np_array_img'
],
crop_paste_x
=
50
,
crop_paste_y
=
50
table_res_list_all_page
.
append
({
'table_res'
:
table_res
,
'lang'
:
_lang
,
'table_img'
:
table_img
,
})
# OCR检测处理
if
self
.
enable_ocr_det_batch
:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
all_cropped_images_info
=
[]
for
ocr_res_list_dict
in
tqdm
(
ocr_res_list_all_page
,
desc
=
"Preparing OCR-det batches"
):
_lang
=
ocr_res_list_dict
[
'lang'
]
for
res
in
ocr_res_list_dict
[
'ocr_res_list'
]:
new_image
,
useful_list
=
crop_img
(
res
,
ocr_res_list_dict
[
'np_array_img'
],
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
ocr_res_list_dict
[
'single_page_mfdetrec_res'
],
useful_list
)
# BGR转换
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
all_cropped_images_info
.
append
((
new_image
,
useful_list
,
ocr_res_list_dict
,
res
,
adjusted_mfdetrec_res
,
_lang
))
# 按语言分组
lang_groups
=
defaultdict
(
list
)
for
crop_info
in
all_cropped_images_info
:
lang
=
crop_info
[
5
]
lang_groups
[
lang
].
append
(
crop_info
)
# 对每种语言按分辨率分组并批处理
for
lang
,
lang_crop_list
in
lang_groups
.
items
():
if
not
lang_crop_list
:
continue
logger
.
info
(
f
"Processing OCR detection for language
{
lang
}
with
{
len
(
lang_crop_list
)
}
images"
)
# 获取OCR模型
atom_model_manager
=
AtomModelSingleton
()
ocr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
'ocr'
,
ocr_show_log
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
lang
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
ocr_res_list_dict
[
'single_page_mfdetrec_res'
],
useful_list
# 按分辨率分组并同时完成padding
resolution_groups
=
defaultdict
(
list
)
for
crop_info
in
lang_crop_list
:
cropped_img
=
crop_info
[
0
]
h
,
w
=
cropped_img
.
shape
[:
2
]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
normalized_h
=
((
h
+
32
)
//
32
)
*
32
# 向上取整到32的倍数
normalized_w
=
((
w
+
32
)
//
32
)
*
32
group_key
=
(
normalized_h
,
normalized_w
)
resolution_groups
[
group_key
].
append
(
crop_info
)
# 对每个分辨率组进行批处理
for
group_key
,
group_crops
in
tqdm
(
resolution_groups
.
items
(),
desc
=
f
"OCR-det
{
lang
}
"
):
raw_images
=
[
crop_info
[
0
]
for
crop_info
in
group_crops
]
# 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
max_h
=
max
(
img
.
shape
[
0
]
for
img
in
raw_images
)
max_w
=
max
(
img
.
shape
[
1
]
for
img
in
raw_images
)
target_h
=
((
max_h
+
32
-
1
)
//
32
)
*
32
target_w
=
((
max_w
+
32
-
1
)
//
32
)
*
32
# 对所有图像进行padding到统一尺寸
batch_images
=
[]
for
img
in
raw_images
:
h
,
w
=
img
.
shape
[:
2
]
# 创建目标尺寸的白色背景
padded_img
=
np
.
ones
((
target_h
,
target_w
,
3
),
dtype
=
np
.
uint8
)
*
255
# 将原图像粘贴到左上角
padded_img
[:
h
,
:
w
]
=
img
batch_images
.
append
(
padded_img
)
# 批处理检测
batch_size
=
min
(
len
(
batch_images
),
self
.
batch_ratio
*
16
)
# 增加批处理大小
logger
.
debug
(
f
"OCR-det batch:
{
batch_size
}
images, target size:
{
target_h
}
x
{
target_w
}
"
)
batch_results
=
ocr_model
.
text_detector
.
batch_predict
(
batch_images
,
batch_size
)
# 处理批处理结果
for
i
,
(
crop_info
,
(
dt_boxes
,
elapse
))
in
enumerate
(
zip
(
group_crops
,
batch_results
)):
new_image
,
useful_list
,
ocr_res_list_dict
,
res
,
adjusted_mfdetrec_res
,
_lang
=
crop_info
if
dt_boxes
is
not
None
:
# 构造OCR结果格式 - 每个box应该是4个点的列表
ocr_res
=
[
box
.
tolist
()
for
box
in
dt_boxes
]
if
ocr_res
:
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
,
ocr_res_list_dict
[
'ocr_enable'
],
new_image
,
_lang
)
if
res
[
"category_id"
]
==
3
:
# ocr_result_list中所有bbox的面积之和
ocr_res_area
=
sum
(
get_coords_and_area
(
ocr_res_item
)[
4
]
for
ocr_res_item
in
ocr_result_list
if
'poly'
in
ocr_res_item
)
# 求ocr_res_area和res的面积的比值
res_area
=
get_coords_and_area
(
res
)[
4
]
if
res_area
>
0
:
ratio
=
ocr_res_area
/
res_area
if
ratio
>
0.25
:
res
[
"category_id"
]
=
1
else
:
continue
ocr_res_list_dict
[
'layout_res'
].
extend
(
ocr_result_list
)
else
:
# 原始单张处理模式
for
ocr_res_list_dict
in
tqdm
(
ocr_res_list_all_page
,
desc
=
"OCR-det Predict"
):
# Process each area that requires OCR processing
_lang
=
ocr_res_list_dict
[
'lang'
]
# Get OCR results for this language's images
atom_model_manager
=
AtomModelSingleton
()
ocr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
'ocr'
,
ocr_show_log
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
_lang
)
for
res
in
ocr_res_list_dict
[
'ocr_res_list'
]:
new_image
,
useful_list
=
crop_img
(
res
,
ocr_res_list_dict
[
'np_array_img'
],
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
ocr_res_list_dict
[
'single_page_mfdetrec_res'
],
useful_list
)
# OCR-det
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
...
...
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py
View file @
99d4c97a
...
...
@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
self
.
net
.
eval
()
self
.
net
.
to
(
self
.
device
)
def _batch_process_same_size(self, img_list):
    """Run text detection on equally sized images with one forward pass.

    Each image is preprocessed on its own, the preprocessed arrays are
    stacked into one batch, the network runs once, and the network output
    is then post-processed image by image.

    Args:
        img_list: Images expected to share one size, so that their
            preprocessed arrays can be stacked into a single batch.

    Returns:
        A ``(batch_results, total_elapse)`` tuple. ``batch_results`` holds
        one ``(dt_boxes, elapse)`` pair per input image, in input order.
        ``total_elapse`` is the wall-clock time in seconds spent on the
        whole batch, post-processing included; each per-image ``elapse``
        is that total divided by the number of images.

    Raises:
        NotImplementedError: If ``self.det_algorithm`` is not one of the
            algorithms handled below.
    """
    if not img_list:
        return [], 0.0

    starttime = time.time()

    # Preprocess every image.
    batch_data = []
    batch_shapes = []
    # Only the original shapes are needed later (to clip the boxes), so
    # record them instead of copying every image.
    ori_shapes = []
    for img in img_list:
        ori_shapes.append(img.shape)
        data = transform({'image': img}, self.preprocess_op)
        if data is None:
            # Preprocessing rejected an image: report an empty result for
            # every image of this batch.
            return [(None, 0) for _ in img_list], 0
        img_processed, shape_list = data
        batch_data.append(img_processed)
        batch_shapes.append(shape_list)

    # Stack into batch arrays.
    try:
        batch_tensor = np.stack(batch_data, axis=0)
        batch_shapes = np.stack(batch_shapes, axis=0)
    except Exception:
        # Best effort: if the arrays cannot be stacked (e.g. the images
        # did not share one size after all), detect one image at a time.
        batch_results = []
        for img in img_list:
            dt_boxes, elapse = self.__call__(img)
            batch_results.append((dt_boxes, elapse))
        return batch_results, time.time() - starttime

    # Batched inference.
    with torch.no_grad():
        inp = torch.from_numpy(batch_tensor).to(self.device)
        outputs = self.net(inp)

    # Move the outputs the selected algorithm needs to NumPy.
    if self.det_algorithm == "EAST":
        preds = {
            key: outputs[key].cpu().numpy()
            for key in ('f_geo', 'f_score')
        }
    elif self.det_algorithm == 'SAST':
        preds = {
            key: outputs[key].cpu().numpy()
            for key in ('f_border', 'f_score', 'f_tco', 'f_tvo')
        }
    elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
        preds = {'maps': outputs['maps'].cpu().numpy()}
    elif self.det_algorithm == 'FCE':
        preds = {
            'level_{}'.format(i): output.cpu().numpy()
            for i, output in enumerate(outputs.values())
        }
    else:
        raise NotImplementedError(
            'Unsupported detection algorithm for batch prediction: '
            '{!r}'.format(self.det_algorithm)
        )

    # The clip-only filter is chosen by configuration alone, so decide
    # once instead of once per image.
    clip_only = (
        (self.det_algorithm == "SAST" and self.det_sast_polygon)
        or (self.det_algorithm in ["PSE", "FCE"]
            and self.postprocess_op.box_type == 'poly')
    )

    # Post-process each image separately.
    all_boxes = []
    for i, ori_shape in enumerate(ori_shapes):
        # Slice with i:i + 1 so the batch dimension is kept.
        single_preds = {
            key: value[i:i + 1] if isinstance(value, np.ndarray) else value
            for key, value in preds.items()
        }
        post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
        dt_boxes = post_result[0]['points']

        # Filter and clip the detected boxes against the original image.
        if clip_only:
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_shape)
        all_boxes.append(dt_boxes)

    # Take the time only now so that it covers post-processing as well.
    total_elapse = time.time() - starttime
    per_image_elapse = total_elapse / len(img_list)
    batch_results = [(dt_boxes, per_image_elapse) for dt_boxes in all_boxes]
    return batch_results, total_elapse
def batch_predict(self, img_list, max_batch_size=8):
    """Detect text in several images, processing them chunk by chunk.

    The images are split, in order, into consecutive chunks of at most
    ``max_batch_size`` images, and every chunk is handed to
    ``_batch_process_same_size``.

    Args:
        img_list: Images to run detection on. The images of one chunk are
            expected to share a size; this method does not check it.
        max_batch_size: Upper bound on the number of images per chunk.
            Values below 1 are treated as 1.

    Returns:
        A list with one ``(dt_boxes, elapse)`` tuple per input image, in
        input order. Empty when ``img_list`` is empty.
    """
    if not img_list:
        return []

    # A step of 0 makes range() raise, and a negative step produces no
    # chunks at all, which would silently drop every image.
    chunk_size = max(1, max_batch_size)

    batch_results = []
    for start in range(0, len(img_list), chunk_size):
        chunk = img_list[start:start + chunk_size]
        chunk_results, _ = self._batch_process_same_size(chunk)
        batch_results.extend(chunk_results)

    return batch_results
def
order_points_clockwise
(
self
,
pts
):
"""
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment