Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
f6bc4f70
Unverified
Commit
f6bc4f70
authored
Mar 26, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Mar 26, 2025
Browse files
Merge pull request #2003 from icecraft/feat/batch_analyze_with_ocr_and_lang
feat: batch inference with ocr and lang flag
parents
2c8470b0
bbba2a12
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
60 deletions
+47
-60
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+20
-8
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+24
-38
magic_pdf/tools/common.py
magic_pdf/tools/common.py
+3
-14
No files found.
magic_pdf/model/batch_analyze.py
View file @
f6bc4f70
...
@@ -17,13 +17,25 @@ MFR_BASE_BATCH_SIZE = 16
...
@@ -17,13 +17,25 @@ MFR_BASE_BATCH_SIZE = 16
class
BatchAnalyze
:
class
BatchAnalyze
:
def
__init__
(
self
,
model
:
CustomPEKModel
,
batch_ratio
:
int
):
def
__init__
(
self
,
model
_manager
,
batch_ratio
:
int
,
show_log
,
layout_model
,
formula_enable
,
table_enable
):
self
.
model
=
model
self
.
model
_manager
=
model_manager
self
.
batch_ratio
=
batch_ratio
self
.
batch_ratio
=
batch_ratio
self
.
show_log
=
show_log
def
__call__
(
self
,
images
:
list
)
->
list
:
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
__call__
(
self
,
images_with_extra_info
:
list
)
->
list
:
if
len
(
images_with_extra_info
)
==
0
:
return
[]
images_layout_res
=
[]
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
layout_start_time
=
time
.
time
()
_
,
fst_ocr
,
fst_lang
=
images_with_extra_info
[
0
]
self
.
model
=
self
.
model_manager
.
get_model
(
fst_ocr
,
self
.
show_log
,
fst_lang
,
self
.
layout_model
,
self
.
formula_enable
,
self
.
table_enable
)
images
=
[
image
for
image
,
_
,
_
in
images_with_extra_info
]
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
# layoutlmv3
for
image
in
images
:
for
image
in
images
:
...
@@ -79,6 +91,8 @@ class BatchAnalyze:
...
@@ -79,6 +91,8 @@ class BatchAnalyze:
table_count
=
0
table_count
=
0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for
index
in
range
(
len
(
images
)):
for
index
in
range
(
len
(
images
)):
_
,
ocr_enable
,
_lang
=
images_with_extra_info
[
index
]
self
.
model
=
self
.
model_manager
.
get_model
(
ocr_enable
,
self
.
show_log
,
_lang
,
self
.
layout_model
,
self
.
formula_enable
,
self
.
table_enable
)
layout_res
=
images_layout_res
[
index
]
layout_res
=
images_layout_res
[
index
]
np_array_img
=
images
[
index
]
np_array_img
=
images
[
index
]
...
@@ -99,7 +113,7 @@ class BatchAnalyze:
...
@@ -99,7 +113,7 @@ class BatchAnalyze:
# OCR recognition
# OCR recognition
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
model
.
apply_ocr
:
if
ocr_enable
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
)[
0
]
...
@@ -159,9 +173,7 @@ class BatchAnalyze:
...
@@ -159,9 +173,7 @@ class BatchAnalyze:
table_count
+=
len
(
table_res_list
)
table_count
+=
len
(
table_res_list
)
if
self
.
model
.
apply_ocr
:
if
self
.
model
.
apply_ocr
:
logger
.
info
(
f
'ocr time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
logger
.
info
(
f
'det or det time costs:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
else
:
logger
.
info
(
f
'det time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
if
self
.
model
.
apply_table
:
if
self
.
model
.
apply_table
:
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
...
...
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
f6bc4f70
...
@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
...
@@ -15,7 +15,7 @@ os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
import
magic_pdf.model
as
model_config
import
magic_pdf.model
as
model_config
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.clean_memory
import
clean_memory
...
@@ -150,12 +150,13 @@ def doc_analyze(
...
@@ -150,12 +150,13 @@ def doc_analyze(
img_dict
=
page_data
.
get_image
()
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
images_with_extra_info
=
[(
images
[
index
],
ocr
,
dataset
.
_lang
)
for
index
in
range
(
len
(
dataset
))]
if
len
(
images
)
>=
MIN_BATCH_INFERENCE_SIZE
:
if
len
(
images
)
>=
MIN_BATCH_INFERENCE_SIZE
:
batch_size
=
MIN_BATCH_INFERENCE_SIZE
batch_size
=
MIN_BATCH_INFERENCE_SIZE
batch_images
=
[
images
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
),
batch_size
)]
batch_images
=
[
images
_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
_with_extra_info
),
batch_size
)]
else
:
else
:
batch_images
=
[
images
]
batch_images
=
[
images
_with_extra_info
]
results
=
[]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
for
sn
,
batch_image
in
enumerate
(
batch_images
):
...
@@ -181,7 +182,7 @@ def doc_analyze(
...
@@ -181,7 +182,7 @@ def doc_analyze(
def
batch_doc_analyze
(
def
batch_doc_analyze
(
datasets
:
list
[
Dataset
],
datasets
:
list
[
Dataset
],
ocr
:
bool
=
False
,
parse_method
:
str
,
show_log
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
,
lang
=
None
,
layout_model
=
None
,
layout_model
=
None
,
...
@@ -192,47 +193,31 @@ def batch_doc_analyze(
...
@@ -192,47 +193,31 @@ def batch_doc_analyze(
batch_size
=
MIN_BATCH_INFERENCE_SIZE
batch_size
=
MIN_BATCH_INFERENCE_SIZE
images
=
[]
images
=
[]
page_wh_list
=
[]
page_wh_list
=
[]
lang_list
=
[]
lang_s
=
set
()
images_with_extra_info
=
[]
for
dataset
in
datasets
:
for
dataset
in
datasets
:
for
index
in
range
(
len
(
dataset
)):
for
index
in
range
(
len
(
dataset
)):
if
lang
is
None
or
lang
==
'auto'
:
if
lang
is
None
or
lang
==
'auto'
:
lang_list
.
append
(
dataset
.
_lang
)
_lang
=
dataset
.
_lang
else
:
else
:
lang_list
.
append
(
lang
)
_lang
=
lang
lang_s
.
add
(
lang_list
[
-
1
])
page_data
=
dataset
.
get_page
(
index
)
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
parse_method
==
'auto'
:
images_with_extra_info
.
append
((
images
[
-
1
],
dataset
.
classify
()
==
SupportedPdfParseMethod
.
OCR
,
_lang
))
else
:
images_with_extra_info
.
append
((
images
[
-
1
],
parse_method
==
'ocr'
,
_lang
))
batch_images
=
[]
batch_images
=
[
images_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images_with_extra_info
),
batch_size
)]
img_idx_list
=
[]
results
=
[]
for
t_lang
in
lang_s
:
for
sn
,
batch_image
in
enumerate
(
batch_images
):
tmp_img_idx_list
=
[]
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
True
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
for
i
,
_lang
in
enumerate
(
lang_list
):
results
.
extend
(
result
)
if
_lang
==
t_lang
:
tmp_img_idx_list
.
append
(
i
)
img_idx_list
.
extend
(
tmp_img_idx_list
)
if
batch_size
>=
len
(
tmp_img_idx_list
):
batch_images
.
append
((
t_lang
,
[
images
[
j
]
for
j
in
tmp_img_idx_list
]))
else
:
slices
=
[
tmp_img_idx_list
[
k
:
k
+
batch_size
]
for
k
in
range
(
0
,
len
(
tmp_img_idx_list
),
batch_size
)]
for
arr
in
slices
:
batch_images
.
append
((
t_lang
,
[
images
[
j
]
for
j
in
arr
]))
unorder_results
=
[]
for
sn
,
(
_lang
,
batch_image
)
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
_lang
,
layout_model
,
formula_enable
,
table_enable
)
unorder_results
.
extend
(
result
)
results
=
[
None
]
*
len
(
img_idx_list
)
for
i
,
idx
in
enumerate
(
img_idx_list
):
results
[
idx
]
=
unorder_results
[
i
]
infer_results
=
[]
infer_results
=
[]
from
magic_pdf.operators.models
import
InferenceResult
from
magic_pdf.operators.models
import
InferenceResult
for
index
in
range
(
len
(
datasets
)):
for
index
in
range
(
len
(
datasets
)):
dataset
=
datasets
[
index
]
dataset
=
datasets
[
index
]
...
@@ -248,9 +233,9 @@ def batch_doc_analyze(
...
@@ -248,9 +233,9 @@ def batch_doc_analyze(
def
may_batch_image_analyze
(
def
may_batch_image_analyze
(
images
:
list
[
np
.
ndarray
],
images
_with_extra_info
:
list
[
(
np
.
ndarray
,
bool
,
str
)
],
idx
:
int
,
idx
:
int
,
ocr
:
bool
=
False
,
ocr
:
bool
,
show_log
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
,
lang
=
None
,
layout_model
=
None
,
layout_model
=
None
,
...
@@ -267,6 +252,7 @@ def may_batch_image_analyze(
...
@@ -267,6 +252,7 @@ def may_batch_image_analyze(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
)
images
=
[
image
for
image
,
_
,
_
in
images_with_extra_info
]
batch_analyze
=
False
batch_analyze
=
False
batch_ratio
=
1
batch_ratio
=
1
device
=
get_device
()
device
=
get_device
()
...
@@ -306,8 +292,8 @@ def may_batch_image_analyze(
...
@@ -306,8 +292,8 @@ def may_batch_image_analyze(
images.append(img_dict['img'])
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
page_wh_list.append((img_dict['width'], img_dict['height']))
"""
"""
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
batch_model
=
BatchAnalyze
(
model
_manager
,
batch_ratio
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
=
batch_model
(
images
)
results
=
batch_model
(
images
_with_extra_info
)
"""
"""
for index in range(len(dataset)):
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
if start_page_id <= index <= end_page_id:
...
...
magic_pdf/tools/common.py
View file @
f6bc4f70
...
@@ -314,21 +314,10 @@ def batch_do_parse(
...
@@ -314,21 +314,10 @@ def batch_do_parse(
dss
.
append
(
PymuDocDataset
(
v
,
lang
=
lang
))
dss
.
append
(
PymuDocDataset
(
v
,
lang
=
lang
))
else
:
else
:
dss
.
append
(
v
)
dss
.
append
(
v
)
dss_with_fn
=
list
(
zip
(
dss
,
pdf_file_names
))
if
parse_method
==
'auto'
:
infer_results
=
batch_doc_analyze
(
dss
,
parse_method
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
dss_typed_txt
=
[(
i
,
x
)
for
i
,
x
in
enumerate
(
dss_with_fn
)
if
x
[
0
].
classify
()
==
SupportedPdfParseMethod
.
TXT
]
dss_typed_ocr
=
[(
i
,
x
)
for
i
,
x
in
enumerate
(
dss_with_fn
)
if
x
[
0
].
classify
()
==
SupportedPdfParseMethod
.
OCR
]
infer_results
=
[
None
]
*
len
(
dss_with_fn
)
infer_results_txt
=
batch_doc_analyze
([
x
[
1
][
0
]
for
x
in
dss_typed_txt
],
lang
=
lang
,
ocr
=
False
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
infer_results_ocr
=
batch_doc_analyze
([
x
[
1
][
0
]
for
x
in
dss_typed_ocr
],
lang
=
lang
,
ocr
=
True
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
for
i
,
infer_res
in
enumerate
(
infer_results_txt
):
infer_results
[
dss_typed_txt
[
i
][
0
]]
=
infer_res
for
i
,
infer_res
in
enumerate
(
infer_results_ocr
):
infer_results
[
dss_typed_ocr
[
i
][
0
]]
=
infer_res
else
:
infer_results
=
batch_doc_analyze
(
dss
,
lang
=
lang
,
ocr
=
parse_method
==
'ocr'
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
for
idx
,
infer_result
in
enumerate
(
infer_results
):
for
idx
,
infer_result
in
enumerate
(
infer_results
):
_do_parse
(
output_dir
,
dss_with_fn
[
idx
]
[
1
]
,
dss
_with_fn
[
idx
]
[
0
]
,
infer_result
.
get_infer_res
(),
parse_method
,
debug_able
,
f_draw_span_bbox
=
f_draw_span_bbox
,
f_draw_layout_bbox
=
f_draw_layout_bbox
,
f_dump_md
=
f_dump_md
,
f_dump_middle_json
=
f_dump_middle_json
,
f_dump_model_json
=
f_dump_model_json
,
f_dump_orig_pdf
=
f_dump_orig_pdf
,
f_dump_content_list
=
f_dump_content_list
,
f_make_md_mode
=
f_make_md_mode
,
f_draw_model_bbox
=
f_draw_model_bbox
,
f_draw_line_sort_bbox
=
f_draw_line_sort_bbox
,
f_draw_char_bbox
=
f_draw_char_bbox
,
lang
=
lang
)
_do_parse
(
output_dir
,
pdf_file_names
[
idx
],
dss
[
idx
],
infer_result
.
get_infer_res
(),
parse_method
,
debug_able
,
f_draw_span_bbox
=
f_draw_span_bbox
,
f_draw_layout_bbox
=
f_draw_layout_bbox
,
f_dump_md
=
f_dump_md
,
f_dump_middle_json
=
f_dump_middle_json
,
f_dump_model_json
=
f_dump_model_json
,
f_dump_orig_pdf
=
f_dump_orig_pdf
,
f_dump_content_list
=
f_dump_content_list
,
f_make_md_mode
=
f_make_md_mode
,
f_draw_model_bbox
=
f_draw_model_bbox
,
f_draw_line_sort_bbox
=
f_draw_line_sort_bbox
,
f_draw_char_bbox
=
f_draw_char_bbox
,
lang
=
lang
)
parse_pdf_methods
=
click
.
Choice
([
'ocr'
,
'txt'
,
'auto'
])
parse_pdf_methods
=
click
.
Choice
([
'ocr'
,
'txt'
,
'auto'
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment