Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
bbba2a12
Commit
bbba2a12
authored
Mar 26, 2025
by
icecraft
Browse files
feat: batch inference with ocr and lang flag
parent
2c8470b0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
60 deletions
+47
-60
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+20
-8
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+24
-38
magic_pdf/tools/common.py
magic_pdf/tools/common.py
+3
-14
No files found.
magic_pdf/model/batch_analyze.py
View file @
bbba2a12
...
...
@@ -17,13 +17,25 @@ MFR_BASE_BATCH_SIZE = 16
class
BatchAnalyze
:
def
__init__
(
self
,
model
:
CustomPEKModel
,
batch_ratio
:
int
):
self
.
model
=
model
def
__init__
(
self
,
model
_manager
,
batch_ratio
:
int
,
show_log
,
layout_model
,
formula_enable
,
table_enable
):
self
.
model
_manager
=
model_manager
self
.
batch_ratio
=
batch_ratio
def
__call__
(
self
,
images
:
list
)
->
list
:
self
.
show_log
=
show_log
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
__call__
(
self
,
images_with_extra_info
:
list
)
->
list
:
if
len
(
images_with_extra_info
)
==
0
:
return
[]
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
_
,
fst_ocr
,
fst_lang
=
images_with_extra_info
[
0
]
self
.
model
=
self
.
model_manager
.
get_model
(
fst_ocr
,
self
.
show_log
,
fst_lang
,
self
.
layout_model
,
self
.
formula_enable
,
self
.
table_enable
)
images
=
[
image
for
image
,
_
,
_
in
images_with_extra_info
]
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
for
image
in
images
:
...
...
@@ -79,6 +91,8 @@ class BatchAnalyze:
table_count
=
0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for
index
in
range
(
len
(
images
)):
_
,
ocr_enable
,
_lang
=
images_with_extra_info
[
index
]
self
.
model
=
self
.
model_manager
.
get_model
(
ocr_enable
,
self
.
show_log
,
_lang
,
self
.
layout_model
,
self
.
formula_enable
,
self
.
table_enable
)
layout_res
=
images_layout_res
[
index
]
np_array_img
=
images
[
index
]
...
...
@@ -99,7 +113,7 @@ class BatchAnalyze:
# OCR recognition
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
model
.
apply_ocr
:
if
ocr_enable
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
...
...
@@ -159,9 +173,7 @@ class BatchAnalyze:
table_count
+=
len
(
table_res_list
)
if
self
.
model
.
apply_ocr
:
logger
.
info
(
f
'ocr time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
else
:
logger
.
info
(
f
'det time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
logger
.
info
(
f
'det or det time costs:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
if
self
.
model
.
apply_table
:
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
...
...
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
bbba2a12
...
...
@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from
loguru
import
logger
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
import
magic_pdf.model
as
model_config
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
...
...
@@ -150,12 +150,13 @@ def doc_analyze(
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
images_with_extra_info
=
[(
images
[
index
],
ocr
,
dataset
.
_lang
)
for
index
in
range
(
len
(
dataset
))]
if
len
(
images
)
>=
MIN_BATCH_INFERENCE_SIZE
:
batch_size
=
MIN_BATCH_INFERENCE_SIZE
batch_images
=
[
images
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
),
batch_size
)]
batch_images
=
[
images
_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
_with_extra_info
),
batch_size
)]
else
:
batch_images
=
[
images
]
batch_images
=
[
images
_with_extra_info
]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
...
...
@@ -181,7 +182,7 @@ def doc_analyze(
def
batch_doc_analyze
(
datasets
:
list
[
Dataset
],
ocr
:
bool
=
False
,
parse_method
:
str
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
...
...
@@ -192,47 +193,31 @@ def batch_doc_analyze(
batch_size
=
MIN_BATCH_INFERENCE_SIZE
images
=
[]
page_wh_list
=
[]
lang_list
=
[]
lang_s
=
set
()
images_with_extra_info
=
[]
for
dataset
in
datasets
:
for
index
in
range
(
len
(
dataset
)):
if
lang
is
None
or
lang
==
'auto'
:
lang_list
.
append
(
dataset
.
_lang
)
_lang
=
dataset
.
_lang
else
:
lang_list
.
append
(
lang
)
lang_s
.
add
(
lang_list
[
-
1
])
_lang
=
lang
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
parse_method
==
'auto'
:
images_with_extra_info
.
append
((
images
[
-
1
],
dataset
.
classify
()
==
SupportedPdfParseMethod
.
OCR
,
_lang
))
else
:
images_with_extra_info
.
append
((
images
[
-
1
],
parse_method
==
'ocr'
,
_lang
))
batch_images
=
[]
img_idx_list
=
[]
for
t_lang
in
lang_s
:
tmp_img_idx_list
=
[]
for
i
,
_lang
in
enumerate
(
lang_list
):
if
_lang
==
t_lang
:
tmp_img_idx_list
.
append
(
i
)
img_idx_list
.
extend
(
tmp_img_idx_list
)
if
batch_size
>=
len
(
tmp_img_idx_list
):
batch_images
.
append
((
t_lang
,
[
images
[
j
]
for
j
in
tmp_img_idx_list
]))
else
:
slices
=
[
tmp_img_idx_list
[
k
:
k
+
batch_size
]
for
k
in
range
(
0
,
len
(
tmp_img_idx_list
),
batch_size
)]
for
arr
in
slices
:
batch_images
.
append
((
t_lang
,
[
images
[
j
]
for
j
in
arr
]))
unorder_results
=
[]
for
sn
,
(
_lang
,
batch_image
)
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
_lang
,
layout_model
,
formula_enable
,
table_enable
)
unorder_results
.
extend
(
result
)
results
=
[
None
]
*
len
(
img_idx_list
)
for
i
,
idx
in
enumerate
(
img_idx_list
):
results
[
idx
]
=
unorder_results
[
i
]
batch_images
=
[
images_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images_with_extra_info
),
batch_size
)]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
True
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
results
.
extend
(
result
)
infer_results
=
[]
from
magic_pdf.operators.models
import
InferenceResult
for
index
in
range
(
len
(
datasets
)):
dataset
=
datasets
[
index
]
...
...
@@ -248,9 +233,9 @@ def batch_doc_analyze(
def
may_batch_image_analyze
(
images
:
list
[
np
.
ndarray
],
images
_with_extra_info
:
list
[
(
np
.
ndarray
,
bool
,
str
)
],
idx
:
int
,
ocr
:
bool
=
False
,
ocr
:
bool
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
...
...
@@ -267,6 +252,7 @@ def may_batch_image_analyze(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
images
=
[
image
for
image
,
_
,
_
in
images_with_extra_info
]
batch_analyze
=
False
batch_ratio
=
1
device
=
get_device
()
...
...
@@ -306,8 +292,8 @@ def may_batch_image_analyze(
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
"""
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
results
=
batch_model
(
images
)
batch_model
=
BatchAnalyze
(
model
_manager
,
batch_ratio
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
=
batch_model
(
images
_with_extra_info
)
"""
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
...
...
magic_pdf/tools/common.py
View file @
bbba2a12
...
...
@@ -314,21 +314,10 @@ def batch_do_parse(
dss
.
append
(
PymuDocDataset
(
v
,
lang
=
lang
))
else
:
dss
.
append
(
v
)
dss_with_fn
=
list
(
zip
(
dss
,
pdf_file_names
))
if
parse_method
==
'auto'
:
dss_typed_txt
=
[(
i
,
x
)
for
i
,
x
in
enumerate
(
dss_with_fn
)
if
x
[
0
].
classify
()
==
SupportedPdfParseMethod
.
TXT
]
dss_typed_ocr
=
[(
i
,
x
)
for
i
,
x
in
enumerate
(
dss_with_fn
)
if
x
[
0
].
classify
()
==
SupportedPdfParseMethod
.
OCR
]
infer_results
=
[
None
]
*
len
(
dss_with_fn
)
infer_results_txt
=
batch_doc_analyze
([
x
[
1
][
0
]
for
x
in
dss_typed_txt
],
lang
=
lang
,
ocr
=
False
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
infer_results_ocr
=
batch_doc_analyze
([
x
[
1
][
0
]
for
x
in
dss_typed_ocr
],
lang
=
lang
,
ocr
=
True
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
for
i
,
infer_res
in
enumerate
(
infer_results_txt
):
infer_results
[
dss_typed_txt
[
i
][
0
]]
=
infer_res
for
i
,
infer_res
in
enumerate
(
infer_results_ocr
):
infer_results
[
dss_typed_ocr
[
i
][
0
]]
=
infer_res
else
:
infer_results
=
batch_doc_analyze
(
dss
,
lang
=
lang
,
ocr
=
parse_method
==
'ocr'
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
infer_results
=
batch_doc_analyze
(
dss
,
parse_method
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
for
idx
,
infer_result
in
enumerate
(
infer_results
):
_do_parse
(
output_dir
,
dss_with_fn
[
idx
]
[
1
]
,
dss
_with_fn
[
idx
]
[
0
]
,
infer_result
.
get_infer_res
(),
parse_method
,
debug_able
,
f_draw_span_bbox
=
f_draw_span_bbox
,
f_draw_layout_bbox
=
f_draw_layout_bbox
,
f_dump_md
=
f_dump_md
,
f_dump_middle_json
=
f_dump_middle_json
,
f_dump_model_json
=
f_dump_model_json
,
f_dump_orig_pdf
=
f_dump_orig_pdf
,
f_dump_content_list
=
f_dump_content_list
,
f_make_md_mode
=
f_make_md_mode
,
f_draw_model_bbox
=
f_draw_model_bbox
,
f_draw_line_sort_bbox
=
f_draw_line_sort_bbox
,
f_draw_char_bbox
=
f_draw_char_bbox
,
lang
=
lang
)
_do_parse
(
output_dir
,
pdf_file_names
[
idx
],
dss
[
idx
],
infer_result
.
get_infer_res
(),
parse_method
,
debug_able
,
f_draw_span_bbox
=
f_draw_span_bbox
,
f_draw_layout_bbox
=
f_draw_layout_bbox
,
f_dump_md
=
f_dump_md
,
f_dump_middle_json
=
f_dump_middle_json
,
f_dump_model_json
=
f_dump_model_json
,
f_dump_orig_pdf
=
f_dump_orig_pdf
,
f_dump_content_list
=
f_dump_content_list
,
f_make_md_mode
=
f_make_md_mode
,
f_draw_model_bbox
=
f_draw_model_bbox
,
f_draw_line_sort_bbox
=
f_draw_line_sort_bbox
,
f_draw_char_bbox
=
f_draw_char_bbox
,
lang
=
lang
)
parse_pdf_methods
=
click
.
Choice
([
'ocr'
,
'txt'
,
'auto'
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment