wangsen / MinerU · Commit 99d4c97a
Authored May 28, 2025 by speta

Support batched OCR text detection (batch-ocr-det); roughly 3x speedup (200-page PDF on a 3090).

Parent: f5016508
Showing 2 changed files, with 272 additions and 38 deletions:

  magic_pdf/model/batch_analyze.py                                                +150 / -38
  magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py   +122 / -0
magic_pdf/model/batch_analyze.py (view file @ 99d4c97a)
@@ -2,6 +2,8 @@ import time

```python
import cv2
from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
```
@@ -16,27 +18,28 @@ MFR_BASE_BATCH_SIZE = 16

```python
class BatchAnalyze:
    def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable, enable_ocr_det_batch=True):
        self.model_manager = model_manager
        self.batch_ratio = batch_ratio
        self.show_log = show_log
        self.layout_model = layout_model
        self.formula_enable = formula_enable
        self.table_enable = table_enable
        self.enable_ocr_det_batch = enable_ocr_det_batch

    def __call__(self, images_with_extra_info: list) -> list:
        if len(images_with_extra_info) == 0:
            return []

        images_layout_res = []

        layout_start_time = time.time()
        self.model = self.model_manager.get_model(
            ocr=True,
            show_log=self.show_log,
            lang=None,
            layout_model=self.layout_model,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )
        images = [image for image, _, _ in images_with_extra_info]
```
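The only behavioral switch added by this commit is the `enable_ocr_det_batch` flag above, which defaults to `True`. As a minimal sketch (not part of the commit), the constructor only stores its arguments, so the flag can be toggled with placeholder values; the argument values below are assumptions, not taken from the repository:

```python
# Minimal usage sketch, not from this commit: BatchAnalyze.__init__ only stores
# attributes, so the new flag can be exercised without loading any models.
from magic_pdf.model.batch_analyze import BatchAnalyze

analyzer = BatchAnalyze(
    model_manager=None,          # placeholder; a real model manager is needed for __call__
    batch_ratio=1,
    show_log=False,
    layout_model=None,           # placeholder; only consumed later by get_model()
    formula_enable=True,
    table_enable=True,
    enable_ocr_det_batch=False,  # opt out and keep the original per-image OCR-det loop
)
print(analyzer.enable_ocr_det_batch)  # False
```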
@@ -101,43 +104,152 @@ class BatchAnalyze:

```python
                get_res_list_from_layout_res(layout_res)
            )

            ocr_res_list_all_page.append({
                'ocr_res_list': ocr_res_list,
                'lang': _lang,
                'ocr_enable': ocr_enable,
                'np_array_img': np_array_img,
                'single_page_mfdetrec_res': single_page_mfdetrec_res,
                'layout_res': layout_res,
            })

            for table_res in table_res_list:
                table_img, _ = crop_img(table_res, np_array_img)
                table_res_list_all_page.append({
                    'table_res': table_res,
                    'lang': _lang,
                    'table_img': table_img,
                })

        # OCR text-box detection
        det_start = time.time()
        det_count = 0
        if self.enable_ocr_det_batch:
            # Batch mode: group by language and resolution
            # Collect every cropped image that needs OCR detection
            all_cropped_images_info = []
            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="Preparing OCR-det batches"):
                _lang = ocr_res_list_dict['lang']
                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )
                    # BGR conversion
                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                    all_cropped_images_info.append((
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
                    ))

            # Group by language
            lang_groups = defaultdict(list)
            for crop_info in all_cropped_images_info:
                lang = crop_info[5]
                lang_groups[lang].append(crop_info)

            # For each language, group by resolution and run batched detection
            for lang, lang_crop_list in lang_groups.items():
                if not lang_crop_list:
                    continue

                logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")

                # Get the OCR model
                atom_model_manager = AtomModelSingleton()
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang=lang
                )

                # Group by resolution (padding is done at the same time)
                resolution_groups = defaultdict(list)
                for crop_info in lang_crop_list:
                    cropped_img = crop_info[0]
                    h, w = cropped_img.shape[:2]
                    # Use a coarser grouping tolerance to reduce the number of groups:
                    # normalize sizes to multiples of 32
                    normalized_h = ((h + 32) // 32) * 32  # round up to a multiple of 32
                    normalized_w = ((w + 32) // 32) * 32
                    group_key = (normalized_h, normalized_w)
                    resolution_groups[group_key].append(crop_info)

                # Batch-process each resolution group
                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
                    raw_images = [crop_info[0] for crop_info in group_crops]

                    # Target size: the largest size within the group, rounded up to a multiple of 32
                    max_h = max(img.shape[0] for img in raw_images)
                    max_w = max(img.shape[1] for img in raw_images)
                    target_h = ((max_h + 32 - 1) // 32) * 32
                    target_w = ((max_w + 32 - 1) // 32) * 32

                    # Pad all images to the unified size
                    batch_images = []
                    for img in raw_images:
                        h, w = img.shape[:2]
                        # Create a white canvas of the target size
                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
                        # Paste the original image into the top-left corner
                        padded_img[:h, :w] = img
                        batch_images.append(padded_img)

                    # Batched detection
                    batch_size = min(len(batch_images), self.batch_ratio * 16)  # enlarge the batch size
                    logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
                    batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)

                    # Handle the batched results
                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info

                        if dt_boxes is not None:
                            # Build the OCR result format: each box should be a list of 4 points
                            ocr_res = [box.tolist() for box in dt_boxes]

                            if ocr_res:
                                ocr_result_list = get_ocr_result_list(
                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
                                )

                                if res["category_id"] == 3:
                                    # Sum of the areas of all bboxes in ocr_result_list
                                    ocr_res_area = sum(
                                        get_coords_and_area(ocr_res_item)[4]
                                        for ocr_res_item in ocr_result_list
                                        if 'poly' in ocr_res_item
                                    )
                                    # Ratio of ocr_res_area to the area of res
                                    res_area = get_coords_and_area(res)[4]
                                    if res_area > 0:
                                        ratio = ocr_res_area / res_area
                                        if ratio > 0.25:
                                            res["category_id"] = 1
                                        else:
                                            continue

                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
        else:
            # Original per-image processing mode
            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
                # Process each area that requires OCR processing
                _lang = ocr_res_list_dict['lang']
                # Get OCR results for this language's images
                atom_model_manager = AtomModelSingleton()
                ocr_model = atom_model_manager.get_atom_model(
                    atom_model_name='ocr',
                    ocr_show_log=False,
                    det_db_box_thresh=0.3,
                    lang=_lang
                )
                for res in ocr_res_list_dict['ocr_res_list']:
                    new_image, useful_list = crop_img(
                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                    )
                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                    )

                    # OCR-det
                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
```
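The crux of the speedup is that crops can only share a forward pass once they have identical shapes, so the hunk above buckets crops by size rounded up to a multiple of 32 and pads each bucket onto a shared white canvas. Below is a minimal standalone sketch of just that grouping-and-padding step; the random arrays stand in for the real crops and all names are local to the sketch:

```python
import numpy as np
from collections import defaultdict

# Three crops of different sizes, stand-ins for the cropped OCR regions.
crops = [np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
         for h, w in [(45, 200), (52, 210), (120, 400)]]

# Group by size normalized to multiples of 32 (same tolerance as the hunk above).
groups = defaultdict(list)
for img in crops:
    h, w = img.shape[:2]
    key = (((h + 32) // 32) * 32, ((w + 32) // 32) * 32)
    groups[key].append(img)

for key, group in groups.items():
    # Target size = largest size in the group, rounded up to a multiple of 32.
    target_h = ((max(im.shape[0] for im in group) + 32 - 1) // 32) * 32
    target_w = ((max(im.shape[1] for im in group) + 32 - 1) // 32) * 32
    batch = []
    for im in group:
        h, w = im.shape[:2]
        canvas = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255  # white background
        canvas[:h, :w] = im                                              # paste at top-left
        batch.append(canvas)
    stacked = np.stack(batch, axis=0)  # uniform shape, so the detector can consume one tensor
    print(key, stacked.shape)
```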
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py (view file @ 99d4c97a)

@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
```python
        self.net.eval()
        self.net.to(self.device)

    def _batch_process_same_size(self, img_list):
        """
        Batch-process images that share the same size.

        Args:
            img_list: list of images with identical size

        Returns:
            batch_results: list of per-image results
            total_elapse: total elapsed time
        """
        starttime = time.time()

        # Preprocess all images
        batch_data = []
        batch_shapes = []
        ori_imgs = []

        for img in img_list:
            ori_im = img.copy()
            ori_imgs.append(ori_im)

            data = {'image': img}
            data = transform(data, self.preprocess_op)
            if data is None:
                # If preprocessing fails, return empty results
                return [(None, 0) for _ in img_list], 0

            img_processed, shape_list = data
            batch_data.append(img_processed)
            batch_shapes.append(shape_list)

        # Stack into a batch tensor
        try:
            batch_tensor = np.stack(batch_data, axis=0)
            batch_shapes = np.stack(batch_shapes, axis=0)
        except Exception as e:
            # If stacking fails, fall back to per-image processing
            batch_results = []
            for img in img_list:
                dt_boxes, elapse = self.__call__(img)
                batch_results.append((dt_boxes, elapse))
            return batch_results, time.time() - starttime

        # Batched inference
        with torch.no_grad():
            inp = torch.from_numpy(batch_tensor)
            inp = inp.to(self.device)
            outputs = self.net(inp)

        # Collect the outputs
        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs['f_geo'].cpu().numpy()
            preds['f_score'] = outputs['f_score'].cpu().numpy()
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs['f_border'].cpu().numpy()
            preds['f_score'] = outputs['f_score'].cpu().numpy()
            preds['f_tco'] = outputs['f_tco'].cpu().numpy()
            preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
            preds['maps'] = outputs['maps'].cpu().numpy()
        elif self.det_algorithm == 'FCE':
            for i, (k, output) in enumerate(outputs.items()):
                preds['level_{}'.format(i)] = output.cpu().numpy()
        else:
            raise NotImplementedError

        # Post-process each image's result
        batch_results = []
        total_elapse = time.time() - starttime

        for i in range(len(img_list)):
            # Extract the predictions for a single image
            single_preds = {}
            for key, value in preds.items():
                if isinstance(value, np.ndarray):
                    single_preds[key] = value[i:i + 1]  # keep the batch dimension
                else:
                    single_preds[key] = value

            # Post-processing
            post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
            dt_boxes = post_result[0]['points']

            # Filter and clip the detection boxes
            if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
                self.det_algorithm in ["PSE", "FCE"] and self.postprocess_op.box_type == 'poly'
            ):
                dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[i].shape)
            else:
                dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[i].shape)

            batch_results.append((dt_boxes, total_elapse / len(img_list)))

        return batch_results, total_elapse

    def batch_predict(self, img_list, max_batch_size=8):
        """
        Batched prediction that detects multiple images at once.

        Args:
            img_list: list of images
            max_batch_size: maximum batch size

        Returns:
            batch_results: list of results, one (dt_boxes, elapse) tuple per image
        """
        if not img_list:
            return []

        batch_results = []

        # Process in batches
        for i in range(0, len(img_list), max_batch_size):
            batch_imgs = img_list[i:i + max_batch_size]
            # sizes within a batch are asserted/assumed to be consistent
            batch_dt_boxes, batch_elapse = self._batch_process_same_size(batch_imgs)
            batch_results.extend(batch_dt_boxes)

        return batch_results

    def order_points_clockwise(self, pts):
        """
        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
```
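Two details of the new methods are worth calling out: slicing the stacked prediction maps with `i:i + 1` keeps the batch dimension, so the existing single-image `postprocess_op` can be reused unchanged, and `batch_predict` simply walks the input in `max_batch_size` slices while assuming the caller already grouped the images by resolution. A standalone sketch with stand-in arrays (shapes are illustrative, not taken from the commit):

```python
import numpy as np

# Stand-in for a DB-head output on a batch of 4 padded crops, shaped like preds['maps'] above.
batch_maps = np.random.rand(4, 1, 96, 224).astype(np.float32)
batch_shapes = np.stack([np.array([96.0, 224.0, 1.0, 1.0]) for _ in range(4)], axis=0)

# Per-image split used by _batch_process_same_size: the i:i+1 slice keeps the batch
# dimension, so each slice still looks like a batch of size 1 to the post-processor.
for i in range(len(batch_maps)):
    single_pred = {'maps': batch_maps[i:i + 1]}  # shape (1, 1, 96, 224)
    single_shape = batch_shapes[i:i + 1]         # shape (1, 4)
    assert single_pred['maps'].ndim == batch_maps.ndim

# Chunking used by batch_predict: split the (already same-resolution) image list
# into max_batch_size slices and run each slice through one forward pass.
max_batch_size = 8
img_list = list(range(19))  # stand-ins for 19 same-size crops
chunks = [img_list[i:i + max_batch_size] for i in range(0, len(img_list), max_batch_size)]
print([len(c) for c in chunks])  # [8, 8, 3]
```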