wangsen / MinerU · commit ecdd162f

Unverified commit ecdd162f, authored Mar 13, 2025 by Xiaomeng Zhao, committed by GitHub on Mar 13, 2025.

Merge pull request #1910 from icecraft/fix/parallel_split

    Fix/parallel split

Parents: 734ae27b, c67a4793

Showing 9 changed files with 615 additions and 82 deletions (+615 −82):
- magic_pdf/data/batch_build_dataset.py: +156 −0
- magic_pdf/data/dataset.py: +40 −23
- magic_pdf/data/utils.py: +103 −0
- magic_pdf/model/doc_analyze_by_custom_model.py: +179 −22
- magic_pdf/model/sub_modules/model_init.py: +28 −16
- magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py: +1 −0
- magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py: +1 −0
- magic_pdf/tools/cli.py: +19 −10
- magic_pdf/tools/common.py: +88 −11
magic_pdf/data/batch_build_dataset.py (new file, 0 → 100644)

```python
import concurrent.futures

import fitz  # PyMuPDF

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image


def partition_array_greedy(arr, k):
    """Partition an array into k parts using a simple greedy approach.

    Parameters:
    -----------
    arr : list
        The input array of (pdf_path, page_count) tuples
    k : int
        Number of partitions to create

    Returns:
    --------
    partitions : list of lists
        The k partitions of the array, each holding indices into arr
    """
    # Handle edge cases
    if k <= 0:
        raise ValueError('k must be a positive integer')
    if k > len(arr):
        k = len(arr)  # Adjust k if it's too large
    if k == 1:
        return [list(range(len(arr)))]
    if k == len(arr):
        return [[i] for i in range(len(arr))]

    # Sort the indices by page count in descending order
    sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)

    # Initialize k empty partitions
    partitions = [[] for _ in range(k)]
    partition_sums = [0] * k

    # Assign each element to the partition with the smallest current sum
    for idx in sorted_indices:
        # Find the partition with the smallest sum
        min_sum_idx = partition_sums.index(min(partition_sums))
        # Add the element to this partition
        partitions[min_sum_idx].append(idx)  # Store the original index
        partition_sums[min_sum_idx] += arr[idx][1]

    return partitions


def process_pdf_batch(pdf_jobs, idx):
    """Process a batch of PDFs, rendering every page of each document.

    Parameters:
    -----------
    pdf_jobs : list of tuples
        List of (pdf_path, page_count) tuples
    idx : int
        Index of this batch, returned alongside the results

    Returns:
    --------
    (idx, images) : tuple
        The batch index and the per-document lists of rendered page images
    """
    images = []
    for pdf_path, _ in pdf_jobs:
        doc = fitz.open(pdf_path)
        tmp = []
        for page_num in range(len(doc)):
            page = doc[page_num]
            tmp.append(fitz_doc_to_image(page))
        images.append(tmp)
    return (idx, images)


def batch_build_dataset(pdf_paths, k, lang=None):
    """Process multiple PDFs by partitioning them into k balanced parts and
    processing each part in parallel.

    Parameters:
    -----------
    pdf_paths : list
        List of paths to PDF files
    k : int
        Number of partitions to create
    lang : str or None
        Language hint passed to PymuDocDataset

    Returns:
    --------
    results : list
        One PymuDocDataset per input path, in the original input order
    """
    # Get page counts for each PDF
    pdf_info = []
    total_pages = 0
    for pdf_path in pdf_paths:
        try:
            doc = fitz.open(pdf_path)
            num_pages = len(doc)
            pdf_info.append((pdf_path, num_pages))
            total_pages += num_pages
            doc.close()
        except Exception as e:
            print(f'Error opening {pdf_path}: {e}')

    # Partition the jobs based on page count
    partitions = partition_array_greedy(pdf_info, k)

    # Process each partition in parallel
    all_images_h = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
        # Submit one task per partition
        futures = []
        for sn, partition in enumerate(partitions):
            # Get the jobs for this partition
            partition_jobs = [pdf_info[idx] for idx in partition]
            # Submit the task
            future = executor.submit(process_pdf_batch, partition_jobs, sn)
            futures.append(future)
        # Collect results as they complete
        for i, future in enumerate(concurrent.futures.as_completed(futures)):
            try:
                idx, images = future.result()
                all_images_h[idx] = images
            except Exception as e:
                print(f'Error processing partition: {e}')

    # Rebuild one dataset per input path, restoring the original order
    results = [None] * len(pdf_paths)
    for i in range(len(partitions)):
        partition = partitions[i]
        for j in range(len(partition)):
            with open(pdf_info[partition[j]][0], 'rb') as f:
                pdf_bytes = f.read()
            dataset = PymuDocDataset(pdf_bytes, lang=lang)
            dataset.set_images(all_images_h[i][j])
            results[partition[j]] = dataset
    return results
```
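The greedy partitioner is the load-balancing core of this file: documents are sorted by page count, largest first, and each is assigned to whichever partition currently has the fewest total pages. A minimal sketch of that behavior, using hypothetical file names and page counts:

```python
from magic_pdf.data.batch_build_dataset import partition_array_greedy

# Hypothetical (pdf_path, page_count) jobs; only the page count drives balancing.
pdf_info = [('a.pdf', 120), ('b.pdf', 30), ('c.pdf', 90), ('d.pdf', 60), ('e.pdf', 10)]

partitions = partition_array_greedy(pdf_info, k=2)  # -> [[0, 1, 4], [2, 3]]
for part in partitions:
    total = sum(pdf_info[i][1] for i in part)
    print([pdf_info[i][0] for i in part], total, 'pages')
# ['a.pdf', 'b.pdf', 'e.pdf'] 160 pages
# ['c.pdf', 'd.pdf'] 150 pages
```

Largest-first assignment keeps the per-worker page totals close to each other, which is what lets batch_build_dataset render the partitions in parallel without straggler processes.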
magic_pdf/data/dataset.py

```diff
@@ -97,10 +97,10 @@ class Dataset(ABC):
     @abstractmethod
     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.

         Args:
             file_path (str): the file path
         """
         pass
@@ -119,7 +119,7 @@ class Dataset(ABC):
     @abstractmethod
     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset
+        """classify the dataset.

         Returns:
             SupportedPdfParseMethod: _description_
@@ -128,8 +128,7 @@ class Dataset(ABC):
     @abstractmethod
     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         pass
@@ -148,12 +147,14 @@ class PymuDocDataset(Dataset):
         if lang == '':
             self._lang = None
         elif lang == 'auto':
-            from magic_pdf.model.sub_modules.language_detection.utils import auto_detect_lang
+            from magic_pdf.model.sub_modules.language_detection.utils import \
+                auto_detect_lang
             self._lang = auto_detect_lang(bits)
-            logger.info(f"lang: {lang}, detect_lang: {self._lang}")
+            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
         else:
             self._lang = lang
-            logger.info(f"lang: {lang}")
+            logger.info(f'lang: {lang}')

     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)
@@ -186,12 +187,12 @@ class PymuDocDataset(Dataset):
         return self._records[page_id]

     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.

         Args:
             file_path (str): the file path
         """
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
             os.makedirs(dir_name, exist_ok=True)
@@ -212,7 +213,7 @@ class PymuDocDataset(Dataset):
         return proc(self, *args, **kwargs)

     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset
+        """classify the dataset.

         Returns:
             SupportedPdfParseMethod: _description_
@@ -220,10 +221,12 @@ class PymuDocDataset(Dataset):
         return classify(self._data_bits)

     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return PymuDocDataset(self._raw_data)

+    def set_images(self, images):
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])

 class ImageDataset(Dataset):
     def __init__(self, bits: bytes):
@@ -270,10 +273,10 @@ class ImageDataset(Dataset):
         return self._records[page_id]

     def dump_to_file(self, file_path: str):
-        """Dump the file
+        """Dump the file.

         Args:
             file_path (str): the file path
         """
         dir_name = os.path.dirname(file_path)
         if dir_name not in ('', '.', '..'):
@@ -293,7 +296,7 @@ class ImageDataset(Dataset):
         return proc(self, *args, **kwargs)

     def classify(self) -> SupportedPdfParseMethod:
-        """classify the dataset
+        """classify the dataset.

         Returns:
             SupportedPdfParseMethod: _description_
@@ -301,15 +304,19 @@ class ImageDataset(Dataset):
         return SupportedPdfParseMethod.OCR

     def clone(self):
-        """clone this dataset
-        """
+        """clone this dataset."""
         return ImageDataset(self._raw_data)

+    def set_images(self, images):
+        for i in range(len(self._records)):
+            self._records[i].set_image(images[i])

 class Doc(PageableData):
     """Initialized with pymudoc object."""

     def __init__(self, doc: fitz.Page):
         self._doc = doc
+        self._img = None

     def get_image(self):
         """Return the image info.
@@ -321,7 +328,17 @@ class Doc(PageableData):
             height: int
         }
         """
-        return fitz_doc_to_image(self._doc)
+        if self._img is None:
+            self._img = fitz_doc_to_image(self._doc)
+        return self._img
+
+    def set_image(self, img):
+        """
+        Args:
+            img (np.ndarray): the image
+        """
+        if self._img is None:
+            self._img = img

     def get_doc(self) -> fitz.Page:
         """Get the pymudoc object.
```
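The net effect of the Doc changes is a per-page image cache: set_images() pre-populates it with images rendered in worker processes, and get_image() only renders on a cache miss. A minimal usage sketch, assuming a local sample.pdf exists:

```python
from magic_pdf.data.batch_build_dataset import batch_build_dataset

# Pages are rendered once, in a worker process, then cached on each Doc record.
datasets = batch_build_dataset(['sample.pdf'], k=1)  # hypothetical input file
ds = datasets[0]

img_dict = ds.get_page(0).get_image()  # served from the cache, no second render
print(img_dict['width'], img_dict['height'])
```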
magic_pdf/data/utils.py

New imports join the existing ones at the top of the module:

```diff
+import multiprocessing as mp
+import threading
+from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor,
+                                as_completed)
+
 import fitz
 import numpy as np
 from loguru import logger
```

The rest of the change (@@ -65,3 +70,101 @@, after `load_images_from_pdf`) appends three helpers, a benchmark driver, and the benchmark results:

```python
def convert_page(bytes_page):
    # Open a single-page PDF from bytes and render it to an image.
    pdfs = fitz.open('pdf', bytes_page)
    page = pdfs[0]
    return fitz_doc_to_image(page)


def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
    """Process PDF pages in parallel with a serialization-safe approach."""
    if num_workers is None:
        num_workers = mp.cpu_count()

    # Process the extracted page data in parallel
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Process the page data
        results = list(executor.map(convert_page, pages))

    return results


def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
    """Process all pages of a PDF using multiple threads.

    Parameters:
    -----------
    pdf_path : str
        Path to the PDF file
    num_threads : int
        Number of threads to use
    **kwargs :
        Additional arguments for fitz_doc_to_image

    Returns:
    --------
    results : list
        List of processed images, in page order
    """
    # Open the PDF
    doc = fitz.open(pdf_path)
    num_pages = len(doc)

    # Create a list to store results in the correct order
    results = [None] * num_pages

    # Create a thread pool
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Submit all tasks
        futures = {}
        for page_num in range(num_pages):
            page = doc[page_num]
            future = executor.submit(fitz_doc_to_image, page, **kwargs)
            futures[future] = page_num

        # Collect results as they complete
        for future in as_completed(futures):
            page_num = futures[future]
            try:
                results[page_num] = future.result()
            except Exception as e:
                print(f'Error processing page {page_num}: {e}')
                results[page_num] = None

    # Close the document
    doc.close()
    return results


if __name__ == '__main__':
    pdf = fitz.open('/tmp/[MS-DOC].pdf')
    pdf_page = [fitz.open() for i in range(pdf.page_count)]
    [pdf_page[i].insert_pdf(pdf, from_page=i, to_page=i)
     for i in range(pdf.page_count)]
    pdf_page = [v.tobytes() for v in pdf_page]
    results = parallel_process_pdf_safe(pdf_page, num_workers=16)
    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)


""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums    time cost
 1             7.351 sec
 2             6.334 sec
 4             5.968 sec
 8             6.728 sec
16             8.085 sec
"""

""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums    time cost
 1                17.170 sec
 2                10.170 sec
 4                 7.841 sec
 8                 7.900 sec
16                 7.984 sec
"""
```
magic_pdf/model/doc_analyze_by_custom_model.py

```diff
+import concurrent.futures as fut
+import multiprocessing as mp
 import os
 import time
+
+import numpy as np
 import torch

 os.environ['FLAGS_npu_jit_compile'] = '0'  # disable paddle's JIT compilation
 os.environ['FLAGS_use_stride_kernel'] = '0'
 os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # keep albumentations from checking for updates

+# disable paddle's signal handling
+import paddle
+paddle.disable_signal_handler()
+
 from loguru import logger

+from magic_pdf.model.batch_analyze import BatchAnalyze
 from magic_pdf.model.sub_modules.model_utils import get_vram

 try:
@@ -30,8 +31,10 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                           get_local_models_dir,
                                           get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
+
+MIN_BATCH_INFERENCE_SIZE = 100

 class ModelSingleton:
     _instance = None
@@ -72,9 +75,7 @@ def custom_model_init(
     formula_enable=None,
     table_enable=None,
 ):
     model = None

     if model_config.__model_mode__ == 'lite':
         logger.warning(
             'The Lite mode is provided for developers to conduct testing only, and the output quality is '
@@ -132,7 +133,6 @@ def custom_model_init(
     return custom_model

 def doc_analyze(
     dataset: Dataset,
     ocr: bool = False,
@@ -143,13 +143,165 @@ def doc_analyze(
     layout_model=None,
     formula_enable=None,
     table_enable=None,
-) -> InferenceResult:
+    one_shot: bool = True,
+):
     end_page_id = (
         end_page_id
         if end_page_id is not None and end_page_id >= 0
         else len(dataset) - 1
     )
+
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    images = []
+    page_wh_list = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # should check the gpu memory first!
+        # split images into parallel_count batches
+        if parallel_count > 1:
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i + batch_size]
+                            for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        results = []
+        parallel_count = len(batch_images)  # adjust to the real parallel count
+        # using concurrent.futures to analyze
+        """
+        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
+            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
+            for future in fut.as_completed(futures):
+                sn, result = future.result()
+                result_history[sn] = result
+        for key in sorted(result_history.keys()):
+            results.extend(result_history[key])
+        """
+        results = []
+        pool = mp.Pool(processes=parallel_count)
+        mapped_results = pool.starmap(
+            may_batch_image_analyze,
+            [(batch_image, sn, ocr, show_log, lang, layout_model,
+              formula_enable, table_enable)
+             for sn, batch_image in enumerate(batch_images)])
+        for sn, result in mapped_results:
+            results.extend(result)
+    else:
+        _, results = may_batch_image_analyze(
+            images, 0, ocr, show_log, lang,
+            layout_model, formula_enable, table_enable)
+
+    model_json = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+        else:
+            result = []
+            page_height = 0
+            page_width = 0
+        page_info = {'page_no': index, 'width': page_width, 'height': page_height}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
+        model_json.append(page_dict)
+
+    from magic_pdf.operators.models import InferenceResult
+    return InferenceResult(model_json, dataset)
+
+
+def batch_doc_analyze(
+    datasets: list[Dataset],
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+    one_shot: bool = True,
+):
+    parallel_count = None
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    images = []
+    page_wh_list = []
+    for dataset in datasets:
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict['img'])
+            page_wh_list.append((img_dict['width'], img_dict['height']))
+
+    if one_shot and len(images) >= MIN_BATCH_INFERENCE_SIZE:
+        if parallel_count is None:
+            parallel_count = 2  # should check the gpu memory first!
+        # split images into parallel_count batches
+        if parallel_count > 1:
+            batch_size = (len(images) + parallel_count - 1) // parallel_count
+            batch_images = [images[i:i + batch_size]
+                            for i in range(0, len(images), batch_size)]
+        else:
+            batch_images = [images]
+        results = []
+        parallel_count = len(batch_images)  # adjust to the real parallel count
+        # using concurrent.futures to analyze
+        """
+        with fut.ProcessPoolExecutor(max_workers=parallel_count) as executor:
+            futures = [executor.submit(may_batch_image_analyze, batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable) for sn, batch_image in enumerate(batch_images)]
+            for future in fut.as_completed(futures):
+                sn, result = future.result()
+                result_history[sn] = result
+        for key in sorted(result_history.keys()):
+            results.extend(result_history[key])
+        """
+        results = []
+        pool = mp.Pool(processes=parallel_count)
+        mapped_results = pool.starmap(
+            may_batch_image_analyze,
+            [(batch_image, sn, ocr, show_log, lang, layout_model,
+              formula_enable, table_enable)
+             for sn, batch_image in enumerate(batch_images)])
+        for sn, result in mapped_results:
+            results.extend(result)
+    else:
+        _, results = may_batch_image_analyze(
+            images, 0, ocr, show_log, lang,
+            layout_model, formula_enable, table_enable)
+
+    infer_results = []
+    from magic_pdf.operators.models import InferenceResult
+    for index in range(len(datasets)):
+        dataset = datasets[index]
+        model_json = []
+        for i in range(len(dataset)):
+            result = results.pop(0)
+            page_width, page_height = page_wh_list.pop(0)
+            page_info = {'page_no': i, 'width': page_width, 'height': page_height}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
+        infer_results.append(InferenceResult(model_json, dataset))
+    return infer_results
+
+
+def may_batch_image_analyze(
+        images: list[np.ndarray],
+        idx: int,
+        ocr: bool = False,
+        show_log: bool = False,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None):
+    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
+    # disable paddle's signal handling
+    import paddle
+    paddle.disable_signal_handler()
+
+    from magic_pdf.model.batch_analyze import BatchAnalyze
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
@@ -161,14 +313,14 @@ def doc_analyze(
     device = get_device()

     npu_support = False
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         import torch_npu
         if torch_npu.npu.is_available():
             npu_support = True
             torch.npu.set_compile_mode(jit_compile=False)

     if torch.cuda.is_available() and device != 'cpu' or npu_support:
-        gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
+        gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
         if gpu_memory is not None and gpu_memory >= 8:
             if gpu_memory >= 20:
                 batch_ratio = 16
@@ -181,12 +333,10 @@ def doc_analyze(
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
             batch_analyze = True

-    model_json = []
     doc_analyze_start = time.time()

     if batch_analyze:
         # batch analyze
+        """
         images = []
         page_wh_list = []
         for index in range(len(dataset)):
@@ -195,9 +345,10 @@ def doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
+        """
         batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-        analyze_result = batch_model(images)
+        results = batch_model(images)
+        """
         for index in range(len(dataset)):
             if start_page_id <= index <= end_page_id:
                 result = analyze_result.pop(0)
@@ -210,10 +361,10 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
+        """
     else:
         # single analyze
+        """
         for index in range(len(dataset)):
             page_data = dataset.get_page(index)
             img_dict = page_data.get_image()
@@ -230,6 +381,13 @@ def doc_analyze(
             page_info = {'page_no': index, 'width': page_width, 'height': page_height}
             page_dict = {'layout_dets': result, 'page_info': page_info}
             model_json.append(page_dict)
+        """
+        results = []
+        for img_idx, img in enumerate(images):
+            inference_start = time.time()
+            result = custom_model(img)
+            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
+            results.append(result)

     gc_start = time.time()
     clean_memory(get_device())
@@ -237,10 +395,9 @@ def doc_analyze(
     logger.info(f'gc time: {gc_time}')

     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
     logger.info(
         f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
         f' speed: {doc_analyze_speed} pages/second'
     )

-    return InferenceResult(model_json, dataset)
+    return (idx, results)
```
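The batch split in doc_analyze and batch_doc_analyze is plain ceiling division over the collected page images; a worked sketch with assumed numbers:

```python
images = list(range(250))  # stand-in for 250 rendered pages
parallel_count = 3

# Ceiling division so every image lands in exactly one batch.
batch_size = (len(images) + parallel_count - 1) // parallel_count  # 84
batch_images = [images[i:i + batch_size] for i in range(0, len(images), batch_size)]

print([len(b) for b in batch_images])  # [84, 84, 82] -> 3 worker processes
```

Each batch is then handed to may_batch_image_analyze via pool.starmap; since starmap preserves input order, flattening the returned (idx, results) pairs reproduces the original page order.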
magic_pdf/model/sub_modules/model_init.py

```diff
 import os

 import torch
 from loguru import logger

 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
-from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
-from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import \
+    YOLOv11LangDetModel
+from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
+    DocLayoutYOLOModel
+from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
+    Layoutlmv3_Predictor
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel

 try:
-    from magic_pdf_ascend_plugin.libs.license_verifier import load_license, LicenseFormatError, LicenseSignatureError, LicenseExpiredError
-    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
-    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
+    from magic_pdf_ascend_plugin.libs.license_verifier import (
+        LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
+        load_license)
+    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import \
+        ModifiedPaddleOCR
+    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import \
+        RapidTableModel
     license_key = load_license()
     logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
                 f' License expired at {license_key["payload"]["date"]["end_date"]}')
@@ -20,21 +29,24 @@ except Exception as e:
     if isinstance(e, ImportError):
         pass
     elif isinstance(e, LicenseFormatError):
-        logger.error("Ascend Plugin: Invalid license format. Please check the license file.")
+        logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
     elif isinstance(e, LicenseSignatureError):
-        logger.error("Ascend Plugin: Invalid signature. The license may be tampered with.")
+        logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
     elif isinstance(e, LicenseExpiredError):
-        logger.error("Ascend Plugin: License has expired. Please renew your license.")
+        logger.error('Ascend Plugin: License has expired. Please renew your license.')
     elif isinstance(e, FileNotFoundError):
-        logger.error("Ascend Plugin: Not found License file.")
+        logger.error('Ascend Plugin: Not found License file.')
     else:
-        logger.error(f"Ascend Plugin: {e}")
+        logger.error(f'Ascend Plugin: {e}')
     from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
     # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR

 from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
-from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
-from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
+from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
+    StructTableModel
+from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
+    TableMasterPaddleModel

 def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
@@ -55,7 +67,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
 def mfd_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     mfd_model = YOLOv8MFDModel(weight, device)
     return mfd_model
@@ -72,14 +84,14 @@ def layout_model_init(weight, config_file, device):
 def doclayout_yolo_model_init(weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = DocLayoutYOLOModel(weight, device)
     return model

 def langdetect_model_init(langdetect_model_weight, device='cpu'):
-    if str(device).startswith("npu"):
+    if str(device).startswith('npu'):
         device = torch.device(device)
     model = YOLOv11LangDetModel(langdetect_model_weight, device)
     return model
```
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

```diff
@@ -5,6 +5,7 @@ import cv2
 import numpy as np
 import torch
+from paddleocr import PaddleOCR
 from ppocr.utils.logging import get_logger
 from ppocr.utils.utility import alpha_to_color, binarize_img
```
magic_pdf/model/sub_modules/table/tablemaster/tablemaster_paddle.py

```diff
@@ -2,6 +2,7 @@ import os

 import cv2
 import numpy as np
+from paddleocr import PaddleOCR
 from ppstructure.table.predict_table import TableSystem
 from ppstructure.utility import init_args
 from PIL import Image
```
magic_pdf/tools/cli.py

```diff
 import os
 import shutil
 import tempfile
+from pathlib import Path

 import click
 import fitz
 from loguru import logger
-from pathlib import Path

 import magic_pdf.model as model_config
+from magic_pdf.data.batch_build_dataset import batch_build_dataset
 from magic_pdf.data.data_reader_writer import FileBasedDataReader
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.version import __version__
-from magic_pdf.tools.common import do_parse, parse_pdf_methods
+from magic_pdf.tools.common import batch_do_parse, do_parse, parse_pdf_methods
 from magic_pdf.utils.office_to_pdf import convert_file_to_pdf

 pdf_suffixes = ['.pdf']
@@ -94,30 +97,33 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     def read_fn(path: Path):
         if path.suffix in ms_office_suffixes:
             convert_file_to_pdf(str(path), temp_dir)
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
         elif path.suffix in image_suffixes:
             with open(str(path), 'rb') as f:
                 bits = f.read()
             pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
-            fn = os.path.join(temp_dir, f"{path.stem}.pdf")
+            fn = os.path.join(temp_dir, f'{path.stem}.pdf')
             with open(fn, 'wb') as f:
                 f.write(pdf_bytes)
         elif path.suffix in pdf_suffixes:
             fn = str(path)
         else:
-            raise Exception(f"Unknown file suffix: {path.suffix}")
+            raise Exception(f'Unknown file suffix: {path.suffix}')

         disk_rw = FileBasedDataReader(os.path.dirname(fn))
         return disk_rw.read(os.path.basename(fn))

-    def parse_doc(doc_path: Path):
+    def parse_doc(doc_path: Path, dataset: Dataset | None = None):
         try:
             file_name = str(Path(doc_path).stem)
-            pdf_data = read_fn(doc_path)
+            if dataset is None:
+                pdf_data_or_dataset = read_fn(doc_path)
+            else:
+                pdf_data_or_dataset = dataset
             do_parse(
                 output_dir,
                 file_name,
-                pdf_data,
+                pdf_data_or_dataset,
                 [],
                 method,
                 debug_able,
@@ -130,9 +136,12 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
             logger.exception(e)

     if os.path.isdir(path):
+        doc_paths = []
         for doc_path in Path(path).glob('*'):
             if doc_path.suffix in pdf_suffixes + image_suffixes + ms_office_suffixes:
-                parse_doc(doc_path)
+                doc_paths.append(doc_path)
+        datasets = batch_build_dataset(doc_paths, 4, lang)
+        batch_do_parse(output_dir, [str(doc_path.stem) for doc_path in doc_paths],
+                       datasets, method, debug_able, lang=lang)
     else:
         parse_doc(Path(path))
```
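For directories, the CLI now front-loads all rendering with batch_build_dataset (hard-coded to 4 worker processes) and then parses everything in a single batch_do_parse call. Roughly equivalent library usage, with an assumed input directory:

```python
from pathlib import Path

from magic_pdf.data.batch_build_dataset import batch_build_dataset
from magic_pdf.tools.common import batch_do_parse

# Hypothetical input/output locations.
doc_paths = sorted(p for p in Path('input_pdfs').glob('*') if p.suffix == '.pdf')

datasets = batch_build_dataset(doc_paths, 4, None)  # 4 render workers, no lang hint
batch_do_parse('output', [p.stem for p in doc_paths], datasets, 'auto', False)
```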
magic_pdf/tools/common.py

```diff
@@ -8,10 +8,10 @@ import magic_pdf.model as model_config
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset, PymuDocDataset
 from magic_pdf.libs.draw_bbox import draw_char_bbox
-from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.operators.models import InferenceResult
+from magic_pdf.model.doc_analyze_by_custom_model import (batch_doc_analyze,
+                                                         doc_analyze)

 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
@@ -67,10 +67,10 @@ def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_i
     return output_bytes

-def do_parse(
+def _do_parse(
     output_dir,
     pdf_file_name,
-    pdf_bytes,
+    pdf_bytes_or_dataset,
     model_list,
     parse_method,
     debug_able,
@@ -92,16 +92,21 @@ def do_parse(
     formula_enable=None,
     table_enable=None,
 ):
+    from magic_pdf.operators.models import InferenceResult
     if debug_able:
         logger.warning('debug mode is on')
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
         # f_draw_char_bbox = True

-    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
-        pdf_bytes, start_page_id, end_page_id
-    )
+    if isinstance(pdf_bytes_or_dataset, bytes):
+        pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+            pdf_bytes_or_dataset, start_page_id, end_page_id
+        )
+        ds = PymuDocDataset(pdf_bytes, lang=lang)
+    else:
+        ds = pdf_bytes_or_dataset
+        pdf_bytes = ds._raw_data

     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
     image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
@@ -109,8 +114,6 @@ def do_parse(
     )
     image_dir = str(os.path.basename(local_image_dir))

-    ds = PymuDocDataset(pdf_bytes, lang=lang)
-
     if len(model_list) == 0:
         if model_config.__use_inside_model__:
             if parse_method == 'auto':
@@ -241,5 +244,79 @@ def do_parse(
     logger.info(f'local output dir is {local_md_dir}')


+def do_parse(
+    output_dir,
+    pdf_file_name,
+    pdf_bytes_or_dataset,
+    model_list,
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    parallel_count = 1
+    if os.environ.get('MINERU_PARALLEL_INFERENCE_COUNT'):
+        parallel_count = int(os.environ['MINERU_PARALLEL_INFERENCE_COUNT'])
+
+    if parallel_count > 1:
+        if isinstance(pdf_bytes_or_dataset, bytes):
+            pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+                pdf_bytes_or_dataset, start_page_id, end_page_id)
+            ds = PymuDocDataset(pdf_bytes, lang=lang)
+        else:
+            ds = pdf_bytes_or_dataset
+        batch_do_parse(output_dir, [pdf_file_name], [ds], parse_method, debug_able,
+                       f_draw_span_bbox=f_draw_span_bbox,
+                       f_draw_layout_bbox=f_draw_layout_bbox,
+                       f_dump_md=f_dump_md,
+                       f_dump_middle_json=f_dump_middle_json,
+                       f_dump_model_json=f_dump_model_json,
+                       f_dump_orig_pdf=f_dump_orig_pdf,
+                       f_dump_content_list=f_dump_content_list,
+                       f_make_md_mode=f_make_md_mode,
+                       f_draw_model_bbox=f_draw_model_bbox,
+                       f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+                       f_draw_char_bbox=f_draw_char_bbox)
+    else:
+        _do_parse(output_dir, pdf_file_name, pdf_bytes_or_dataset, model_list,
+                  parse_method, debug_able,
+                  start_page_id=start_page_id,
+                  end_page_id=end_page_id,
+                  lang=lang,
+                  layout_model=layout_model,
+                  formula_enable=formula_enable,
+                  table_enable=table_enable,
+                  f_draw_span_bbox=f_draw_span_bbox,
+                  f_draw_layout_bbox=f_draw_layout_bbox,
+                  f_dump_md=f_dump_md,
+                  f_dump_middle_json=f_dump_middle_json,
+                  f_dump_model_json=f_dump_model_json,
+                  f_dump_orig_pdf=f_dump_orig_pdf,
+                  f_dump_content_list=f_dump_content_list,
+                  f_make_md_mode=f_make_md_mode,
+                  f_draw_model_bbox=f_draw_model_bbox,
+                  f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+                  f_draw_char_bbox=f_draw_char_bbox)
+
+
+def batch_do_parse(
+    output_dir,
+    pdf_file_names: list[str],
+    pdf_bytes_or_datasets: list[bytes | Dataset],
+    parse_method,
+    debug_able,
+    f_draw_span_bbox=True,
+    f_draw_layout_bbox=True,
+    f_dump_md=True,
+    f_dump_middle_json=True,
+    f_dump_model_json=True,
+    f_dump_orig_pdf=True,
+    f_dump_content_list=True,
+    f_make_md_mode=MakeMode.MM_MD,
+    f_draw_model_bbox=False,
+    f_draw_line_sort_bbox=False,
+    f_draw_char_bbox=False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
+    dss = []
+    for v in pdf_bytes_or_datasets:
+        if isinstance(v, bytes):
+            dss.append(PymuDocDataset(v, lang=lang))
+        else:
+            dss.append(v)
+    infer_results = batch_doc_analyze(dss, lang=lang, layout_model=layout_model,
+                                      formula_enable=formula_enable,
+                                      table_enable=table_enable, one_shot=True)
+    for idx, infer_result in enumerate(infer_results):
+        _do_parse(output_dir, pdf_file_names[idx], dss[idx],
+                  infer_result.get_infer_res(), parse_method, debug_able,
+                  f_draw_span_bbox=f_draw_span_bbox,
+                  f_draw_layout_bbox=f_draw_layout_bbox,
+                  f_dump_md=f_dump_md,
+                  f_dump_middle_json=f_dump_middle_json,
+                  f_dump_model_json=f_dump_model_json,
+                  f_dump_orig_pdf=f_dump_orig_pdf,
+                  f_dump_content_list=f_dump_content_list,
+                  f_make_md_mode=f_make_md_mode,
+                  f_draw_model_bbox=f_draw_model_bbox,
+                  f_draw_line_sort_bbox=f_draw_line_sort_bbox,
+                  f_draw_char_bbox=f_draw_char_bbox)
+

 parse_pdf_methods = click.Choice(['ocr', 'txt', 'auto'])
```
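The split into _do_parse / do_parse keeps the public signature stable while making MINERU_PARALLEL_INFERENCE_COUNT the single switch between the old per-document path and the new batch path. A minimal sketch of a caller opting in (example.pdf is a placeholder):

```python
import os

os.environ['MINERU_PARALLEL_INFERENCE_COUNT'] = '2'  # >1 routes through batch_do_parse

from magic_pdf.tools.common import do_parse

with open('example.pdf', 'rb') as f:  # hypothetical input file
    pdf_bytes = f.read()

# Same call shape as before this commit; raw bytes or a prebuilt Dataset both work.
do_parse('output', 'example', pdf_bytes, [], 'auto', False)
```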