Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
49bfdf07
Commit
49bfdf07
authored
Dec 13, 2024
by
Suven
Browse files
feat: enhance batch processing in BatchAnalyze with layout and OCR timing logs
parent
4fd1e41e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
8 deletions
+56
-8
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+56
-8
No files found.
magic_pdf/model/batch_analyze.py
View file @
49bfdf07
...
@@ -35,6 +35,8 @@ class BatchAnalyze:
...
@@ -35,6 +35,8 @@ class BatchAnalyze:
def
__call__
(
self
,
images
:
list
)
->
list
:
def
__call__
(
self
,
images
:
list
)
->
list
:
images_layout_res
=
[]
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
# layoutlmv3
for
image
in
images
:
for
image
in
images
:
...
@@ -42,17 +44,52 @@ class BatchAnalyze:
...
@@ -42,17 +44,52 @@ class BatchAnalyze:
images_layout_res
.
append
(
layout_res
)
images_layout_res
.
append
(
layout_res
)
elif
self
.
model
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
elif
self
.
model
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
# doclayout_yolo
layout_images
=
[]
modified_images
=
[]
for
image_index
,
image
in
enumerate
(
images
):
pil_img
=
Image
.
fromarray
(
image
)
width
,
height
=
pil_img
.
size
if
height
>
width
:
input_res
=
{
"poly"
:
[
0
,
0
,
width
,
0
,
width
,
height
,
0
,
height
]}
new_image
,
useful_list
=
crop_img
(
input_res
,
pil_img
,
crop_paste_x
=
width
//
2
,
crop_paste_y
=
0
)
layout_images
.
append
(
new_image
)
modified_images
.
append
([
image_index
,
useful_list
])
else
:
layout_images
.
append
(
pil_img
)
images_layout_res
+=
self
.
model
.
layout_model
.
batch_predict
(
images_layout_res
+=
self
.
model
.
layout_model
.
batch_predict
(
images
,
self
.
batch_ratio
*
YOLO_LAYOUT_BASE_BATCH_SIZE
layout_
images
,
self
.
batch_ratio
*
YOLO_LAYOUT_BASE_BATCH_SIZE
)
)
for
image_index
,
useful_list
in
modified_images
:
for
res
in
images_layout_res
[
image_index
]:
for
i
in
range
(
len
(
res
[
"poly"
])):
if
i
%
2
==
0
:
res
[
"poly"
][
i
]
=
(
res
[
"poly"
][
i
]
-
useful_list
[
0
]
+
useful_list
[
2
]
)
else
:
res
[
"poly"
][
i
]
=
(
res
[
"poly"
][
i
]
-
useful_list
[
1
]
+
useful_list
[
3
]
)
logger
.
info
(
f
"layout time:
{
round
(
time
.
time
()
-
layout_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
"
)
if
self
.
model
.
apply_formula
:
if
self
.
model
.
apply_formula
:
# 公式检测
# 公式检测
mfd_start_time
=
time
.
time
()
images_mfd_res
=
self
.
model
.
mfd_model
.
batch_predict
(
images_mfd_res
=
self
.
model
.
mfd_model
.
batch_predict
(
images
,
self
.
batch_ratio
*
MFD_BASE_BATCH_SIZE
images
,
self
.
batch_ratio
*
MFD_BASE_BATCH_SIZE
)
)
logger
.
info
(
f
"mfd time:
{
round
(
time
.
time
()
-
mfd_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
"
)
# 公式识别
# 公式识别
mfr_start_time
=
time
.
time
()
images_formula_list
=
self
.
model
.
mfr_model
.
batch_predict
(
images_formula_list
=
self
.
model
.
mfr_model
.
batch_predict
(
images_mfd_res
,
images_mfd_res
,
images
,
images
,
...
@@ -60,10 +97,17 @@ class BatchAnalyze:
...
@@ -60,10 +97,17 @@ class BatchAnalyze:
)
)
for
image_index
in
range
(
len
(
images
)):
for
image_index
in
range
(
len
(
images
)):
images_layout_res
[
image_index
]
+=
images_formula_list
[
image_index
]
images_layout_res
[
image_index
]
+=
images_formula_list
[
image_index
]
logger
.
info
(
f
"mfr time:
{
round
(
time
.
time
()
-
mfr_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
"
)
# 清理显存
# 清理显存
clean_vram
(
self
.
model
.
device
,
vram_threshold
=
8
)
clean_vram
(
self
.
model
.
device
,
vram_threshold
=
8
)
ocr_time
=
0
ocr_count
=
0
table_time
=
0
table_count
=
0
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for
index
in
range
(
len
(
images
)):
for
index
in
range
(
len
(
images
)):
layout_res
=
images_layout_res
[
index
]
layout_res
=
images_layout_res
[
index
]
...
@@ -99,12 +143,8 @@ class BatchAnalyze:
...
@@ -99,12 +143,8 @@ class BatchAnalyze:
if
ocr_res
:
if
ocr_res
:
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
)
ocr_result_list
=
get_ocr_result_list
(
ocr_res
,
useful_list
)
layout_res
.
extend
(
ocr_result_list
)
layout_res
.
extend
(
ocr_result_list
)
ocr_time
+=
time
.
time
()
-
ocr_start
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
ocr_count
+=
len
(
ocr_res_list
)
if
self
.
model
.
apply_ocr
:
logger
.
info
(
f
"ocr time:
{
ocr_cost
}
"
)
else
:
logger
.
info
(
f
"det time:
{
ocr_cost
}
"
)
# 表格识别 table recognition
# 表格识别 table recognition
if
self
.
model
.
apply_table
:
if
self
.
model
.
apply_table
:
...
@@ -146,7 +186,13 @@ class BatchAnalyze:
...
@@ -146,7 +186,13 @@ class BatchAnalyze:
logger
.
warning
(
logger
.
warning
(
"table recognition processing fails, not get html return"
"table recognition processing fails, not get html return"
)
)
logger
.
info
(
f
"table time:
{
round
(
time
.
time
()
-
table_start
,
2
)
}
"
)
table_time
+=
time
.
time
()
-
table_start
table_count
+=
len
(
table_res_list
)
if
self
.
model
.
apply_ocr
:
logger
.
info
(
f
"ocr time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
"
)
if
self
.
model
.
apply_table
:
logger
.
info
(
f
"table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
"
)
return
images_layout_res
return
images_layout_res
...
@@ -225,6 +271,8 @@ def doc_batch_analyze(
...
@@ -225,6 +271,8 @@ def doc_batch_analyze(
model_json
.
append
(
page_dict
)
model_json
.
append
(
page_dict
)
# TODO: clean memory when gpu memory is not enough
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time
=
time
.
time
()
clean_memory
()
clean_memory
()
logger
.
info
(
f
"clean memory time:
{
round
(
time
.
time
()
-
clean_memory_start_time
,
2
)
}
"
)
return
InferenceResult
(
model_json
,
dataset
)
return
InferenceResult
(
model_json
,
dataset
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment