Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
b2887ca0
"build_tools/vscode:/vscode.git/clone" did not exist on "2a7164681d35dcdfe0b8d516ed5ab37d18c38127"
Commit
b2887ca0
authored
Dec 18, 2024
by
icecraft
Browse files
refactor: refactor code
parent
303a4b01
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
56 additions
and
166 deletions
+56
-166
magic_pdf/model/__init__.py
magic_pdf/model/__init__.py
+0
-101
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+31
-37
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+3
-3
magic_pdf/operators/__init__.py
magic_pdf/operators/__init__.py
+0
-0
magic_pdf/operators/models.py
magic_pdf/operators/models.py
+2
-4
magic_pdf/operators/pipes.py
magic_pdf/operators/pipes.py
+0
-0
magic_pdf/tools/common.py
magic_pdf/tools/common.py
+3
-3
next_docs/en/api/model_operators.rst
next_docs/en/api/model_operators.rst
+1
-1
next_docs/en/api/pipe_operators.rst
next_docs/en/api/pipe_operators.rst
+2
-2
next_docs/en/user_guide/inference_result.rst
next_docs/en/user_guide/inference_result.rst
+7
-8
next_docs/en/user_guide/pipe_result.rst
next_docs/en/user_guide/pipe_result.rst
+7
-7
No files found.
magic_pdf/model/__init__.py
View file @
b2887ca0
from
typing
import
Callable
from
abc
import
ABC
,
abstractmethod
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.pipe.operators
import
PipeResult
__use_inside_model__
=
True
__model_mode__
=
"full"
class
InferenceResultBase
(
ABC
):
@
abstractmethod
def
__init__
(
self
,
inference_results
:
list
,
dataset
:
Dataset
):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
self
.
_infer_res
=
inference_results
self
.
_dataset
=
dataset
@
abstractmethod
def
draw_model
(
self
,
file_path
:
str
)
->
None
:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
pass
@
abstractmethod
def
dump_model
(
self
,
writer
:
DataWriter
,
file_path
:
str
):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
pass
@
abstractmethod
def
get_infer_res
(
self
):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
pass
@
abstractmethod
def
apply
(
self
,
proc
:
Callable
,
*
args
,
**
kwargs
):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@
abstractmethod
def
pipe_txt_mode
(
self
,
imageWriter
:
DataWriter
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
lang
=
None
,
)
->
PipeResult
:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@
abstractmethod
def
pipe_ocr_mode
(
self
,
imageWriter
:
DataWriter
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
lang
=
None
,
)
->
PipeResult
:
pass
magic_pdf/model/batch_analyze.py
View file @
b2887ca0
...
@@ -11,17 +11,12 @@ from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
...
@@ -11,17 +11,12 @@ from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.model.doc_analyze_by_custom_model
import
ModelSingleton
from
magic_pdf.model.doc_analyze_by_custom_model
import
ModelSingleton
from
magic_pdf.model.operators
import
InferenceResult
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.sub_modules.model_utils
import
(
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
crop_img
,
get_res_list_from_layout_res
,
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
get_ocr_result_list
,
from
magic_pdf.operators.models
import
InferenceResult
)
YOLO_LAYOUT_BASE_BATCH_SIZE
=
4
YOLO_LAYOUT_BASE_BATCH_SIZE
=
4
MFD_BASE_BATCH_SIZE
=
1
MFD_BASE_BATCH_SIZE
=
1
...
@@ -50,7 +45,7 @@ class BatchAnalyze:
...
@@ -50,7 +45,7 @@ class BatchAnalyze:
pil_img
=
Image
.
fromarray
(
image
)
pil_img
=
Image
.
fromarray
(
image
)
width
,
height
=
pil_img
.
size
width
,
height
=
pil_img
.
size
if
height
>
width
:
if
height
>
width
:
input_res
=
{
"
poly
"
:
[
0
,
0
,
width
,
0
,
width
,
height
,
0
,
height
]}
input_res
=
{
'
poly
'
:
[
0
,
0
,
width
,
0
,
width
,
height
,
0
,
height
]}
new_image
,
useful_list
=
crop_img
(
new_image
,
useful_list
=
crop_img
(
input_res
,
pil_img
,
crop_paste_x
=
width
//
2
,
crop_paste_y
=
0
input_res
,
pil_img
,
crop_paste_x
=
width
//
2
,
crop_paste_y
=
0
)
)
...
@@ -65,17 +60,17 @@ class BatchAnalyze:
...
@@ -65,17 +60,17 @@ class BatchAnalyze:
for
image_index
,
useful_list
in
modified_images
:
for
image_index
,
useful_list
in
modified_images
:
for
res
in
images_layout_res
[
image_index
]:
for
res
in
images_layout_res
[
image_index
]:
for
i
in
range
(
len
(
res
[
"
poly
"
])):
for
i
in
range
(
len
(
res
[
'
poly
'
])):
if
i
%
2
==
0
:
if
i
%
2
==
0
:
res
[
"
poly
"
][
i
]
=
(
res
[
'
poly
'
][
i
]
=
(
res
[
"
poly
"
][
i
]
-
useful_list
[
0
]
+
useful_list
[
2
]
res
[
'
poly
'
][
i
]
-
useful_list
[
0
]
+
useful_list
[
2
]
)
)
else
:
else
:
res
[
"
poly
"
][
i
]
=
(
res
[
'
poly
'
][
i
]
=
(
res
[
"
poly
"
][
i
]
-
useful_list
[
1
]
+
useful_list
[
3
]
res
[
'
poly
'
][
i
]
-
useful_list
[
1
]
+
useful_list
[
3
]
)
)
logger
.
info
(
logger
.
info
(
f
"
layout time:
{
round
(
time
.
time
()
-
layout_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
"
f
'
layout time:
{
round
(
time
.
time
()
-
layout_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
)
if
self
.
model
.
apply_formula
:
if
self
.
model
.
apply_formula
:
...
@@ -85,7 +80,7 @@ class BatchAnalyze:
...
@@ -85,7 +80,7 @@ class BatchAnalyze:
images
,
self
.
batch_ratio
*
MFD_BASE_BATCH_SIZE
images
,
self
.
batch_ratio
*
MFD_BASE_BATCH_SIZE
)
)
logger
.
info
(
logger
.
info
(
f
"
mfd time:
{
round
(
time
.
time
()
-
mfd_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
"
f
'
mfd time:
{
round
(
time
.
time
()
-
mfd_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
)
# 公式识别
# 公式识别
...
@@ -98,7 +93,7 @@ class BatchAnalyze:
...
@@ -98,7 +93,7 @@ class BatchAnalyze:
for
image_index
in
range
(
len
(
images
)):
for
image_index
in
range
(
len
(
images
)):
images_layout_res
[
image_index
]
+=
images_formula_list
[
image_index
]
images_layout_res
[
image_index
]
+=
images_formula_list
[
image_index
]
logger
.
info
(
logger
.
info
(
f
"
mfr time:
{
round
(
time
.
time
()
-
mfr_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
"
f
'
mfr time:
{
round
(
time
.
time
()
-
mfr_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
)
# 清理显存
# 清理显存
...
@@ -156,7 +151,7 @@ class BatchAnalyze:
...
@@ -156,7 +151,7 @@ class BatchAnalyze:
if
self
.
model
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
if
self
.
model
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
table_result
=
self
.
model
.
table_model
.
predict
(
table_result
=
self
.
model
.
table_model
.
predict
(
new_image
,
"
html
"
new_image
,
'
html
'
)
)
if
len
(
table_result
)
>
0
:
if
len
(
table_result
)
>
0
:
html_code
=
table_result
[
0
]
html_code
=
table_result
[
0
]
...
@@ -169,32 +164,32 @@ class BatchAnalyze:
...
@@ -169,32 +164,32 @@ class BatchAnalyze:
run_time
=
time
.
time
()
-
single_table_start_time
run_time
=
time
.
time
()
-
single_table_start_time
if
run_time
>
self
.
model
.
table_max_time
:
if
run_time
>
self
.
model
.
table_max_time
:
logger
.
warning
(
logger
.
warning
(
f
"
table recognition processing exceeds max time
{
self
.
model
.
table_max_time
}
s
"
f
'
table recognition processing exceeds max time
{
self
.
model
.
table_max_time
}
s
'
)
)
# 判断是否返回正常
# 判断是否返回正常
if
html_code
:
if
html_code
:
expected_ending
=
html_code
.
strip
().
endswith
(
expected_ending
=
html_code
.
strip
().
endswith
(
"
</html>
"
'
</html>
'
)
or
html_code
.
strip
().
endswith
(
"
</table>
"
)
)
or
html_code
.
strip
().
endswith
(
'
</table>
'
)
if
expected_ending
:
if
expected_ending
:
res
[
"
html
"
]
=
html_code
res
[
'
html
'
]
=
html_code
else
:
else
:
logger
.
warning
(
logger
.
warning
(
"
table recognition processing fails, not found expected HTML table end
"
'
table recognition processing fails, not found expected HTML table end
'
)
)
else
:
else
:
logger
.
warning
(
logger
.
warning
(
"
table recognition processing fails, not get html return
"
'
table recognition processing fails, not get html return
'
)
)
table_time
+=
time
.
time
()
-
table_start
table_time
+=
time
.
time
()
-
table_start
table_count
+=
len
(
table_res_list
)
table_count
+=
len
(
table_res_list
)
if
self
.
model
.
apply_ocr
:
if
self
.
model
.
apply_ocr
:
logger
.
info
(
f
"
ocr time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
"
)
logger
.
info
(
f
'
ocr time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
else
:
else
:
logger
.
info
(
f
"
det time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
"
)
logger
.
info
(
f
'
det time:
{
round
(
ocr_time
,
2
)
}
, image num:
{
ocr_count
}
'
)
if
self
.
model
.
apply_table
:
if
self
.
model
.
apply_table
:
logger
.
info
(
f
"
table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
"
)
logger
.
info
(
f
'
table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
return
images_layout_res
return
images_layout_res
...
@@ -211,8 +206,7 @@ def doc_batch_analyze(
...
@@ -211,8 +206,7 @@ def doc_batch_analyze(
table_enable
=
None
,
table_enable
=
None
,
batch_ratio
:
int
|
None
=
None
,
batch_ratio
:
int
|
None
=
None
,
)
->
InferenceResult
:
)
->
InferenceResult
:
"""
"""Perform batch analysis on a document dataset.
Perform batch analysis on a document dataset.
Args:
Args:
dataset (Dataset): The dataset containing document pages to be analyzed.
dataset (Dataset): The dataset containing document pages to be analyzed.
...
@@ -234,9 +228,9 @@ def doc_batch_analyze(
...
@@ -234,9 +228,9 @@ def doc_batch_analyze(
"""
"""
if
not
torch
.
cuda
.
is_available
():
if
not
torch
.
cuda
.
is_available
():
raise
CUDA_NOT_AVAILABLE
(
"
batch analyze not support in CPU mode
"
)
raise
CUDA_NOT_AVAILABLE
(
'
batch analyze not support in CPU mode
'
)
lang
=
None
if
lang
==
""
else
lang
lang
=
None
if
lang
==
''
else
lang
# TODO: auto detect batch size
# TODO: auto detect batch size
batch_ratio
=
1
if
batch_ratio
is
None
else
batch_ratio
batch_ratio
=
1
if
batch_ratio
is
None
else
batch_ratio
end_page_id
=
end_page_id
if
end_page_id
else
len
(
dataset
)
end_page_id
=
end_page_id
if
end_page_id
else
len
(
dataset
)
...
@@ -255,26 +249,26 @@ def doc_batch_analyze(
...
@@ -255,26 +249,26 @@ def doc_batch_analyze(
if
start_page_id
<=
index
<=
end_page_id
:
if
start_page_id
<=
index
<=
end_page_id
:
page_data
=
dataset
.
get_page
(
index
)
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
"
img
"
])
images
.
append
(
img_dict
[
'
img
'
])
analyze_result
=
batch_model
(
images
)
analyze_result
=
batch_model
(
images
)
for
index
in
range
(
len
(
dataset
)):
for
index
in
range
(
len
(
dataset
)):
page_data
=
dataset
.
get_page
(
index
)
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
img_dict
=
page_data
.
get_image
()
page_width
=
img_dict
[
"
width
"
]
page_width
=
img_dict
[
'
width
'
]
page_height
=
img_dict
[
"
height
"
]
page_height
=
img_dict
[
'
height
'
]
if
start_page_id
<=
index
<=
end_page_id
:
if
start_page_id
<=
index
<=
end_page_id
:
result
=
analyze_result
.
pop
(
0
)
result
=
analyze_result
.
pop
(
0
)
else
:
else
:
result
=
[]
result
=
[]
page_info
=
{
"
page_no
"
:
index
,
"
height
"
:
page_height
,
"
width
"
:
page_width
}
page_info
=
{
'
page_no
'
:
index
,
'
height
'
:
page_height
,
'
width
'
:
page_width
}
page_dict
=
{
"
layout_dets
"
:
result
,
"
page_info
"
:
page_info
}
page_dict
=
{
'
layout_dets
'
:
result
,
'
page_info
'
:
page_info
}
model_json
.
append
(
page_dict
)
model_json
.
append
(
page_dict
)
# TODO: clean memory when gpu memory is not enough
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time
=
time
.
time
()
clean_memory_start_time
=
time
.
time
()
clean_memory
()
clean_memory
()
logger
.
info
(
f
"
clean memory time:
{
round
(
time
.
time
()
-
clean_memory_start_time
,
2
)
}
"
)
logger
.
info
(
f
'
clean memory time:
{
round
(
time
.
time
()
-
clean_memory_start_time
,
2
)
}
'
)
return
InferenceResult
(
model_json
,
dataset
)
return
InferenceResult
(
model_json
,
dataset
)
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
b2887ca0
import
os
import
os
import
time
import
time
from
loguru
import
logger
# 关闭paddle的信号处理
# 关闭paddle的信号处理
import
paddle
import
paddle
from
loguru
import
logger
paddle
.
disable_signal_handler
()
paddle
.
disable_signal_handler
()
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
...
@@ -25,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
...
@@ -25,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir
,
get_local_models_dir
,
get_table_recog_config
)
get_table_recog_config
)
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.
model.
operators
import
InferenceResult
from
magic_pdf.operators
.models
import
InferenceResult
def
dict_compare
(
d1
,
d2
):
def
dict_compare
(
d1
,
d2
):
...
...
magic_pdf/operators/__init__.py
0 → 100644
View file @
b2887ca0
magic_pdf/
model/
operators.py
→
magic_pdf/operators
/models
.py
View file @
b2887ca0
...
@@ -7,15 +7,13 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
...
@@ -7,15 +7,13 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.filter
import
classify
from
magic_pdf.libs.draw_bbox
import
draw_model_bbox
from
magic_pdf.libs.draw_bbox
import
draw_model_bbox
from
magic_pdf.libs.version
import
__version__
from
magic_pdf.libs.version
import
__version__
from
magic_pdf.
model
import
Inferenc
eResult
Base
from
magic_pdf.
operators.pipes
import
Pip
eResult
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
from
magic_pdf.pipe.operators
import
PipeResult
class
InferenceResult
(
InferenceResultBase
)
:
class
InferenceResult
:
def
__init__
(
self
,
inference_results
:
list
,
dataset
:
Dataset
):
def
__init__
(
self
,
inference_results
:
list
,
dataset
:
Dataset
):
"""Initialized method.
"""Initialized method.
...
...
magic_pdf/
pipe/
operators.py
→
magic_pdf/operators
/pipes
.py
View file @
b2887ca0
File moved
magic_pdf/tools/common.py
View file @
b2887ca0
...
@@ -10,7 +10,7 @@ from magic_pdf.config.make_content_config import DropMode, MakeMode
...
@@ -10,7 +10,7 @@ from magic_pdf.config.make_content_config import DropMode, MakeMode
from
magic_pdf.data.data_reader_writer
import
FileBasedDataWriter
from
magic_pdf.data.data_reader_writer
import
FileBasedDataWriter
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.model.doc_analyze_by_custom_model
import
doc_analyze
from
magic_pdf.model.doc_analyze_by_custom_model
import
doc_analyze
from
magic_pdf.
model.
operators
import
InferenceResult
from
magic_pdf.operators
.models
import
InferenceResult
# from io import BytesIO
# from io import BytesIO
# from pypdf import PdfReader, PdfWriter
# from pypdf import PdfReader, PdfWriter
...
...
next_docs/en/api/model_operators.rst
View file @
b2887ca0
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
Model Api
Model Api
==========
==========
.. autoclass:: magic_pdf.model.InferenceResult
Base
.. autoclass:: magic_pdf.
operators.
model
s
.InferenceResult
:members:
:members:
:inherited-members:
:inherited-members:
:show-inheritance:
:show-inheritance:
next_docs/en/api/pipe_operators.rst
View file @
b2887ca0
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
Pipeline Api
Pipeline Api
=============
=============
.. autoclass:: magic_pdf.
pipe.
operators.PipeResult
.. autoclass:: magic_pdf.operators.
pipes.
PipeResult
:members:
:members:
:inherited-members:
:inherited-members:
:show-inheritance:
:show-inheritance:
next_docs/en/user_guide/inference_result.rst
View file @
b2887ca0
...
@@ -122,7 +122,7 @@ Inference Result
...
@@ -122,7 +122,7 @@ Inference Result
.. code:: python
.. code:: python
from magic_pdf.
model.
operators import InferenceResult
from magic_pdf.operators
.models
import InferenceResult
from magic_pdf.data.dataset import Dataset
from magic_pdf.data.dataset import Dataset
dataset : Dataset = some_data_set # not real dataset
dataset : Dataset = some_data_set # not real dataset
...
@@ -142,4 +142,3 @@ some_model.pdf
...
@@ -142,4 +142,3 @@ some_model.pdf
.. |Poly Coordinate Diagram| image:: ../_static/image/poly.png
.. |Poly Coordinate Diagram| image:: ../_static/image/poly.png
next_docs/en/user_guide/pipe_result.rst
View file @
b2887ca0
...
@@ -294,7 +294,7 @@ Pipeline Result
...
@@ -294,7 +294,7 @@ Pipeline Result
.. code:: python
.. code:: python
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.
pipe.
operators import PipeResult
from magic_pdf.operators
.pipes
import PipeResult
from magic_pdf.data.dataset import Dataset
from magic_pdf.data.dataset import Dataset
res = pdf_parse_union(*args, **kwargs)
res = pdf_parse_union(*args, **kwargs)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment