Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
9ce72d78
"tests/vscode:/vscode.git/clone" did not exist on "2bfa5a61fb21e03cb3e70b0cdace7bd8466a2817"
Commit
9ce72d78
authored
Mar 20, 2025
by
myhloli
Browse files
Merge remote-tracking branch 'origin/dev' into dev
parents
59435d88
27281c92
Changes
39
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3238 additions
and
357 deletions
+3238
-357
docker/ascend_npu/requirements.txt
docker/ascend_npu/requirements.txt
+3
-7
docker/china/requirements.txt
docker/china/requirements.txt
+3
-7
docker/global/requirements.txt
docker/global/requirements.txt
+3
-7
magic-pdf.template.json
magic-pdf.template.json
+1
-1
magic_pdf/data/batch_build_dataset.py
magic_pdf/data/batch_build_dataset.py
+156
-0
magic_pdf/data/dataset.py
magic_pdf/data/dataset.py
+40
-23
magic_pdf/data/utils.py
magic_pdf/data/utils.py
+108
-9
magic_pdf/dict2md/ocr_mkcontent.py
magic_pdf/dict2md/ocr_mkcontent.py
+4
-3
magic_pdf/libs/pdf_image_tools.py
magic_pdf/libs/pdf_image_tools.py
+11
-6
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+5
-116
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+130
-28
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+4
-29
magic_pdf/model/sub_modules/language_detection/utils.py
magic_pdf/model/sub_modules/language_detection/utils.py
+2
-4
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
...f/model/sub_modules/language_detection/yolov11/YOLOv11.py
+24
-19
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
+20
-98
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
.../model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
+13
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
..._modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
+189
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
...dules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
+8
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
...t/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
+163
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py
...mernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py
+2351
-0
No files found.
docker/ascend_npu/requirements.txt
View file @
9ce72d78
...
@@ -7,19 +7,15 @@ numpy>=1.21.6,<2.0.0
...
@@ -7,19 +7,15 @@ numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
fast-langdetect>=0.2.3,<0.3.0
scikit-learn>=1.0.2
scikit-learn>=1.0.2
pdfminer.six==20231228
pdfminer.six==20231228
unimernet==0.2.3
torch==2.3.1
torch>=2.2.2,<=2.3.1
torchvision==0.18.1
torchvision>=0.17.2,<=0.18.1
matplotlib
matplotlib
ultralytics>=8.3.48
ultralytics>=8.3.48
paddleocr==2.7.3
paddleocr==2.7.3
paddlepaddle==3.0.0rc1
paddlepaddle==3.0.0rc1
struct-eqtable==0.3.2
einops
accelerate
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
doclayout-yolo==0.0.2b1
ftfy
openai
openai
detectron2
docker/china/requirements.txt
View file @
9ce72d78
...
@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
...
@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
fast-langdetect>=0.2.3,<0.3.0
scikit-learn>=1.0.2
scikit-learn>=1.0.2
pdfminer.six==20231228
pdfminer.six==20231228
unimernet==0.2.3
torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torch>=2.2.2,<=2.3.1
torchvision
torchvision>=0.17.2,<=0.18.1
matplotlib
matplotlib
ultralytics>=8.3.48
ultralytics>=8.3.48
paddleocr==2.7.3
paddleocr==2.7.3
struct-eqtable==0.3.2
einops
accelerate
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
doclayout-yolo==0.0.2b1
ftfy
openai
openai
detectron2
docker/global/requirements.txt
View file @
9ce72d78
...
@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
...
@@ -7,18 +7,14 @@ numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
fast-langdetect>=0.2.3,<0.3.0
scikit-learn>=1.0.2
scikit-learn>=1.0.2
pdfminer.six==20231228
pdfminer.six==20231228
unimernet==0.2.3
torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torch>=2.2.2,<=2.3.1
torchvision
torchvision>=0.17.2,<=0.18.1
matplotlib
matplotlib
ultralytics>=8.3.48
ultralytics>=8.3.48
paddleocr==2.7.3
paddleocr==2.7.3
struct-eqtable==0.3.2
einops
accelerate
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-paddle>=1.4.5,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapidocr-onnxruntime>=1.4.4,<2.0.0
rapid-table>=1.0.3,<2.0.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
doclayout-yolo==0.0.2b1
ftfy
openai
openai
detectron2
magic-pdf.template.json
View file @
9ce72d78
...
@@ -40,5 +40,5 @@
...
@@ -40,5 +40,5 @@
"enable"
:
false
"enable"
:
false
}
}
},
},
"config_version"
:
"1.
1.1
"
"config_version"
:
"1.
2.0
"
}
}
\ No newline at end of file
magic_pdf/data/batch_build_dataset.py
0 → 100644
View file @
9ce72d78
import
concurrent.futures
import
fitz
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.data.utils
import
fitz_doc_to_image
# PyMuPDF
def partition_array_greedy(arr, k):
    """Partition weighted items into k roughly balanced parts (greedy heuristic).

    Each element of ``arr`` is a ``(item, weight)`` pair — only ``pair[1]``
    is read. Items are visited heaviest-first and each is assigned to the
    partition with the smallest running total, the classic LPT-style greedy
    balancing scheme.

    Parameters:
    -----------
    arr : list of tuple
        Sequence of ``(item, weight)`` pairs; the weight at index 1 is
        what gets balanced across partitions.
    k : int
        Number of partitions to create. Silently clamped to ``len(arr)``
        when larger.

    Returns:
    --------
    partitions : list of lists
        k lists of *indices into arr* (not the items themselves).

    Raises:
    -------
    ValueError
        If k is not a positive integer.
    """
    if k <= 0:
        raise ValueError('k must be a positive integer')
    if k > len(arr):
        k = len(arr)  # Adjust k if it's too large

    # Trivial cases: everything in one part / exactly one item per part.
    if k == 1:
        return [list(range(len(arr)))]
    if k == len(arr):
        return [[i] for i in range(len(arr))]

    # Visit items in descending weight order so large items are spread early.
    sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)

    # Initialize k empty partitions and their running weight totals.
    partitions = [[] for _ in range(k)]
    partition_sums = [0] * k

    # Assign each element to the partition with the smallest current sum.
    for idx in sorted_indices:
        min_sum_idx = partition_sums.index(min(partition_sums))
        partitions[min_sum_idx].append(idx)  # store the original index
        partition_sums[min_sum_idx] += arr[idx][1]

    return partitions
def process_pdf_batch(pdf_jobs, idx):
    """Render every page of each PDF in this batch to an image.

    Note: the page-count element of each job tuple is ignored here — all
    pages of every document are rendered unconditionally.

    Parameters:
    -----------
    pdf_jobs : list of tuples
        List of ``(pdf_path, page_count)`` tuples.
    idx : int
        Identifier of this batch, echoed back in the result so the caller
        can match partitions to results when futures complete out of order.

    Returns:
    --------
    tuple
        ``(idx, images)`` where ``images[i]`` is the page-ordered list of
        rendered images (dicts from ``fitz_doc_to_image``) for
        ``pdf_jobs[i]``.
    """
    images = []
    for pdf_path, _ in pdf_jobs:
        doc = fitz.open(pdf_path)
        try:
            # Render pages in order so each inner list is page-ordered.
            images.append([fitz_doc_to_image(doc[page_num])
                           for page_num in range(len(doc))])
        finally:
            # fix: the original never closed the document, leaking one
            # handle per PDF inside each worker process.
            doc.close()
    return (idx, images)
def batch_build_dataset(pdf_paths, k, lang=None):
    """Build one PymuDocDataset per PDF, rendering pages in k parallel workers.

    PDFs are partitioned into k weight-balanced groups (weight = page
    count), each group is rendered in its own process via
    ``process_pdf_batch``, and the rendered images are then attached to a
    per-PDF ``PymuDocDataset``.

    Parameters:
    -----------
    pdf_paths : list
        Paths to the PDF files.
    k : int
        Number of partitions / worker processes to use.
    lang : str or None
        Language hint forwarded to ``PymuDocDataset``.

    Returns:
    --------
    list
        One ``PymuDocDataset`` per entry of ``pdf_paths``, in the same
        order. Entries remain ``None`` for PDFs that could not be opened
        or whose partition failed to render.
    """
    # Collect (path, page_count) for every openable PDF, remembering each
    # PDF's position in pdf_paths so results line up even if some fail.
    pdf_info = []
    orig_indices = []
    for path_idx, pdf_path in enumerate(pdf_paths):
        try:
            doc = fitz.open(pdf_path)
            pdf_info.append((pdf_path, len(doc)))
            orig_indices.append(path_idx)
            doc.close()
        except Exception as e:
            print(f'Error opening {pdf_path}: {e}')

    # Partition by page count so each worker gets a similar workload.
    partitions = partition_array_greedy(pdf_info, k)

    # Render each partition in its own process.
    all_images_h = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
        futures = []
        for sn, partition in enumerate(partitions):
            partition_jobs = [pdf_info[idx] for idx in partition]
            futures.append(executor.submit(process_pdf_batch, partition_jobs, sn))
        for future in concurrent.futures.as_completed(futures):
            try:
                sn, images = future.result()
                all_images_h[sn] = images
            except Exception as e:
                print(f'Error processing partition: {e}')

    # Build one dataset per PDF and attach its pre-rendered images.
    results = [None] * len(pdf_paths)
    for i, partition in enumerate(partitions):
        if i not in all_images_h:
            # Whole partition failed to render; leave its slots as None
            # instead of raising KeyError below.
            continue
        for j, info_idx in enumerate(partition):
            with open(pdf_info[info_idx][0], 'rb') as f:
                pdf_bytes = f.read()
            dataset = PymuDocDataset(pdf_bytes, lang=lang)
            dataset.set_images(all_images_h[i][j])
            # fix: map through orig_indices — in the original, any PDF that
            # failed to open shifted every later dataset into the wrong
            # results slot (pdf_info indices were used as pdf_paths indices).
            results[orig_indices[info_idx]] = dataset
    return results
magic_pdf/data/dataset.py
View file @
9ce72d78
...
@@ -97,7 +97,7 @@ class Dataset(ABC):
...
@@ -97,7 +97,7 @@ class Dataset(ABC):
@
abstractmethod
@
abstractmethod
def
dump_to_file
(
self
,
file_path
:
str
):
def
dump_to_file
(
self
,
file_path
:
str
):
"""Dump the file
"""Dump the file
.
Args:
Args:
file_path (str): the file path
file_path (str): the file path
...
@@ -119,7 +119,7 @@ class Dataset(ABC):
...
@@ -119,7 +119,7 @@ class Dataset(ABC):
@
abstractmethod
@
abstractmethod
def
classify
(
self
)
->
SupportedPdfParseMethod
:
def
classify
(
self
)
->
SupportedPdfParseMethod
:
"""classify the dataset
"""classify the dataset
.
Returns:
Returns:
SupportedPdfParseMethod: _description_
SupportedPdfParseMethod: _description_
...
@@ -128,8 +128,7 @@ class Dataset(ABC):
...
@@ -128,8 +128,7 @@ class Dataset(ABC):
@
abstractmethod
@
abstractmethod
def
clone
(
self
):
def
clone
(
self
):
"""clone this dataset
"""clone this dataset."""
"""
pass
pass
...
@@ -148,12 +147,14 @@ class PymuDocDataset(Dataset):
...
@@ -148,12 +147,14 @@ class PymuDocDataset(Dataset):
if
lang
==
''
:
if
lang
==
''
:
self
.
_lang
=
None
self
.
_lang
=
None
elif
lang
==
'auto'
:
elif
lang
==
'auto'
:
from
magic_pdf.model.sub_modules.language_detection.utils
import
auto_detect_lang
from
magic_pdf.model.sub_modules.language_detection.utils
import
\
auto_detect_lang
self
.
_lang
=
auto_detect_lang
(
bits
)
self
.
_lang
=
auto_detect_lang
(
bits
)
logger
.
info
(
f
"
lang:
{
lang
}
, detect_lang:
{
self
.
_lang
}
"
)
logger
.
info
(
f
'
lang:
{
lang
}
, detect_lang:
{
self
.
_lang
}
'
)
else
:
else
:
self
.
_lang
=
lang
self
.
_lang
=
lang
logger
.
info
(
f
"lang:
{
lang
}
"
)
logger
.
info
(
f
'lang:
{
lang
}
'
)
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
"""The page number of the pdf."""
"""The page number of the pdf."""
return
len
(
self
.
_records
)
return
len
(
self
.
_records
)
...
@@ -186,7 +187,7 @@ class PymuDocDataset(Dataset):
...
@@ -186,7 +187,7 @@ class PymuDocDataset(Dataset):
return
self
.
_records
[
page_id
]
return
self
.
_records
[
page_id
]
def
dump_to_file
(
self
,
file_path
:
str
):
def
dump_to_file
(
self
,
file_path
:
str
):
"""Dump the file
"""Dump the file
.
Args:
Args:
file_path (str): the file path
file_path (str): the file path
...
@@ -212,7 +213,7 @@ class PymuDocDataset(Dataset):
...
@@ -212,7 +213,7 @@ class PymuDocDataset(Dataset):
return
proc
(
self
,
*
args
,
**
kwargs
)
return
proc
(
self
,
*
args
,
**
kwargs
)
def
classify
(
self
)
->
SupportedPdfParseMethod
:
def
classify
(
self
)
->
SupportedPdfParseMethod
:
"""classify the dataset
"""classify the dataset
.
Returns:
Returns:
SupportedPdfParseMethod: _description_
SupportedPdfParseMethod: _description_
...
@@ -220,10 +221,12 @@ class PymuDocDataset(Dataset):
...
@@ -220,10 +221,12 @@ class PymuDocDataset(Dataset):
return
classify
(
self
.
_data_bits
)
return
classify
(
self
.
_data_bits
)
def
clone
(
self
):
def
clone
(
self
):
"""clone this dataset
"""clone this dataset."""
"""
return
PymuDocDataset
(
self
.
_raw_data
)
return
PymuDocDataset
(
self
.
_raw_data
)
def
set_images
(
self
,
images
):
for
i
in
range
(
len
(
self
.
_records
)):
self
.
_records
[
i
].
set_image
(
images
[
i
])
class
ImageDataset
(
Dataset
):
class
ImageDataset
(
Dataset
):
def
__init__
(
self
,
bits
:
bytes
):
def
__init__
(
self
,
bits
:
bytes
):
...
@@ -270,7 +273,7 @@ class ImageDataset(Dataset):
...
@@ -270,7 +273,7 @@ class ImageDataset(Dataset):
return
self
.
_records
[
page_id
]
return
self
.
_records
[
page_id
]
def
dump_to_file
(
self
,
file_path
:
str
):
def
dump_to_file
(
self
,
file_path
:
str
):
"""Dump the file
"""Dump the file
.
Args:
Args:
file_path (str): the file path
file_path (str): the file path
...
@@ -293,7 +296,7 @@ class ImageDataset(Dataset):
...
@@ -293,7 +296,7 @@ class ImageDataset(Dataset):
return
proc
(
self
,
*
args
,
**
kwargs
)
return
proc
(
self
,
*
args
,
**
kwargs
)
def
classify
(
self
)
->
SupportedPdfParseMethod
:
def
classify
(
self
)
->
SupportedPdfParseMethod
:
"""classify the dataset
"""classify the dataset
.
Returns:
Returns:
SupportedPdfParseMethod: _description_
SupportedPdfParseMethod: _description_
...
@@ -301,15 +304,19 @@ class ImageDataset(Dataset):
...
@@ -301,15 +304,19 @@ class ImageDataset(Dataset):
return
SupportedPdfParseMethod
.
OCR
return
SupportedPdfParseMethod
.
OCR
def
clone
(
self
):
def
clone
(
self
):
"""clone this dataset
"""clone this dataset."""
"""
return
ImageDataset
(
self
.
_raw_data
)
return
ImageDataset
(
self
.
_raw_data
)
def
set_images
(
self
,
images
):
for
i
in
range
(
len
(
self
.
_records
)):
self
.
_records
[
i
].
set_image
(
images
[
i
])
class
Doc
(
PageableData
):
class
Doc
(
PageableData
):
"""Initialized with pymudoc object."""
"""Initialized with pymudoc object."""
def
__init__
(
self
,
doc
:
fitz
.
Page
):
def
__init__
(
self
,
doc
:
fitz
.
Page
):
self
.
_doc
=
doc
self
.
_doc
=
doc
self
.
_img
=
None
def
get_image
(
self
):
def
get_image
(
self
):
"""Return the image info.
"""Return the image info.
...
@@ -321,7 +328,17 @@ class Doc(PageableData):
...
@@ -321,7 +328,17 @@ class Doc(PageableData):
height: int
height: int
}
}
"""
"""
return
fitz_doc_to_image
(
self
.
_doc
)
if
self
.
_img
is
None
:
self
.
_img
=
fitz_doc_to_image
(
self
.
_doc
)
return
self
.
_img
def
set_image
(
self
,
img
):
"""
Args:
img (np.ndarray): the image
"""
if
self
.
_img
is
None
:
self
.
_img
=
img
def
get_doc
(
self
)
->
fitz
.
Page
:
def
get_doc
(
self
)
->
fitz
.
Page
:
"""Get the pymudoc object.
"""Get the pymudoc object.
...
...
magic_pdf/data/utils.py
View file @
9ce72d78
import
multiprocessing
as
mp
import
threading
from
concurrent.futures
import
(
ProcessPoolExecutor
,
ThreadPoolExecutor
,
as_completed
)
import
fitz
import
fitz
import
numpy
as
np
import
numpy
as
np
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.utils.annotations
import
ImportPIL
@
ImportPIL
def
fitz_doc_to_image
(
doc
,
dpi
=
200
)
->
dict
:
def
fitz_doc_to_image
(
doc
,
dpi
=
200
)
->
dict
:
"""Convert fitz.Document to image, Then convert the image to numpy array.
"""Convert fitz.Document to image, Then convert the image to numpy array.
...
@@ -17,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
...
@@ -17,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
Returns:
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
dict: {'img': numpy array, 'width': width, 'height': height }
"""
"""
from
PIL
import
Image
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
doc
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
pm
=
doc
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
...
@@ -25,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
...
@@ -25,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
doc
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
pm
=
doc
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
# Convert pixmap samples directly to numpy array
img
=
np
.
array
(
img
)
img
=
np
.
frombuffer
(
pm
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pm
.
height
,
pm
.
width
,
3
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
return
img_dict
return
img_dict
@
ImportPIL
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
from
PIL
import
Image
images
=
[]
images
=
[]
with
fitz
.
open
(
'pdf'
,
pdf_bytes
)
as
doc
:
with
fitz
.
open
(
'pdf'
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
pdf_page_num
=
doc
.
page_count
...
@@ -57,11 +57,110 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
...
@@ -57,11 +57,110 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
# Convert pixmap samples directly to numpy array
img
=
np
.
array
(
img
)
img
=
np
.
frombuffer
(
pm
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pm
.
height
,
pm
.
width
,
3
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
else
:
else
:
img_dict
=
{
'img'
:
[],
'width'
:
0
,
'height'
:
0
}
img_dict
=
{
'img'
:
[],
'width'
:
0
,
'height'
:
0
}
images
.
append
(
img_dict
)
images
.
append
(
img_dict
)
return
images
return
images
def convert_page(bytes_page):
    """Render the first page of a PDF given as a bytes blob.

    Intended for single-page PDF blobs produced by splitting a document,
    so they can be pickled across process boundaries.

    Parameters:
    -----------
    bytes_page : bytes
        A PDF document serialized to bytes; only page 0 is rendered.

    Returns:
    --------
    dict
        ``{'img': ..., 'width': ..., 'height': ...}`` as produced by
        ``fitz_doc_to_image``.
    """
    # fix: the original opened the document and never closed it; the
    # context manager releases the handle after the page is rendered.
    with fitz.open('pdf', bytes_page) as pdfs:
        return fitz_doc_to_image(pdfs[0])
def parallel_process_pdf_safe(pages, num_workers=None, **kwargs):
    """Render serialized single-page PDFs in a process pool.

    "Serialization-safe" because each element of ``pages`` is a bytes
    blob, which can be pickled to worker processes (unlike live fitz
    objects). ``**kwargs`` is accepted for call-site compatibility but
    is not forwarded to the workers.

    Parameters:
    -----------
    pages : list of bytes
        Single-page PDF blobs, one per page.
    num_workers : int or None
        Size of the process pool; defaults to the CPU count.

    Returns:
    --------
    list
        Rendered image dicts, in the same order as ``pages``.
    """
    worker_count = mp.cpu_count() if num_workers is None else num_workers
    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        rendered = pool.map(convert_page, pages)
        return list(rendered)
def threaded_process_pdf(pdf_path, num_threads=4, **kwargs):
    """Process all pages of a PDF using multiple threads.

    Parameters:
    -----------
    pdf_path : str
        Path to the PDF file
    num_threads : int
        Number of threads to use
    **kwargs :
        Additional arguments for fitz_doc_to_image

    Returns:
    --------
    images : list
        List of processed images, in page order; ``None`` in the slot of
        any page that failed to render.
    """
    # Open the PDF once; all threads share the document's pages.
    doc = fitz.open(pdf_path)
    num_pages = len(doc)

    # Pre-allocate result slots so completion order doesn't matter.
    results = [None] * num_pages
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Map each future back to its page index.
        futures = {
            executor.submit(fitz_doc_to_image, doc[page_num], **kwargs): page_num
            for page_num in range(num_pages)
        }
        for future in as_completed(futures):
            page_num = futures[future]
            try:
                results[page_num] = future.result()
            except Exception as e:
                print(f'Error processing page {page_num}: {e}')
                results[page_num] = None

    # Close the document
    doc.close()
    # fix: the original collected `results` but fell off the end without
    # returning it, despite the documented return value.
    return results
if __name__ == '__main__':
    # Demo / benchmark driver: split one PDF into single-page byte blobs,
    # then render them through the process pool.
    src = fitz.open('/tmp/[MS-DOC].pdf')
    single_page_docs = [fitz.open() for _ in range(src.page_count)]
    for page_idx in range(src.page_count):
        # Copy exactly one page of the source into its own document.
        single_page_docs[page_idx].insert_pdf(src, from_page=page_idx, to_page=page_idx)
    page_blobs = [one_pager.tobytes() for one_pager in single_page_docs]
    results = parallel_process_pdf_safe(page_blobs, num_workers=16)
    # threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums, time cost
1 7.351 sec
2 6.334 sec
4 5.968 sec
8 6.728 sec
16 8.085 sec
"""
""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums, time cost
1 17.170 sec
2 10.170 sec
4 7.841 sec
8 7.900 sec
16 7.984 sec
"""
magic_pdf/dict2md/ocr_mkcontent.py
View file @
9ce72d78
...
@@ -208,12 +208,13 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
...
@@ -208,12 +208,13 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
'text'
:
merge_para_with_text
(
para_block
),
'text'
:
merge_para_with_text
(
para_block
),
}
}
elif
para_type
==
BlockType
.
Title
:
elif
para_type
==
BlockType
.
Title
:
title_level
=
get_title_level
(
para_block
)
para_content
=
{
para_content
=
{
'type'
:
'text'
,
'type'
:
'text'
,
'text'
:
merge_para_with_text
(
para_block
),
'text'
:
merge_para_with_text
(
para_block
),
'text_level'
:
title_level
,
}
}
title_level
=
get_title_level
(
para_block
)
if
title_level
!=
0
:
para_content
[
'text_level'
]
=
title_level
elif
para_type
==
BlockType
.
InterlineEquation
:
elif
para_type
==
BlockType
.
InterlineEquation
:
para_content
=
{
para_content
=
{
'type'
:
'equation'
,
'type'
:
'equation'
,
...
@@ -319,5 +320,5 @@ def get_title_level(block):
...
@@ -319,5 +320,5 @@ def get_title_level(block):
if
title_level
>
4
:
if
title_level
>
4
:
title_level
=
4
title_level
=
4
elif
title_level
<
1
:
elif
title_level
<
1
:
title_level
=
1
title_level
=
0
return
title_level
return
title_level
\ No newline at end of file
magic_pdf/libs/pdf_image_tools.py
View file @
9ce72d78
...
@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
...
@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 截取图片
# 截取图片
pix
=
page
.
get_pixmap
(
clip
=
rect
,
matrix
=
zoom
)
pix
=
page
.
get_pixmap
(
clip
=
rect
,
matrix
=
zoom
)
if
mode
==
"cv2"
:
# 直接转换为numpy数组供cv2使用
img_array
=
np
.
frombuffer
(
pix
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pix
.
height
,
pix
.
width
,
pix
.
n
)
# PyMuPDF使用RGB顺序,而cv2使用BGR顺序
if
pix
.
n
==
3
or
pix
.
n
==
4
:
image_result
=
cv2
.
cvtColor
(
img_array
,
cv2
.
COLOR_RGB2BGR
)
else
:
image_result
=
img_array
elif
mode
==
"pillow"
:
# 将字节数据转换为文件对象
# 将字节数据转换为文件对象
image_file
=
BytesIO
(
pix
.
tobytes
(
output
=
'png'
))
image_file
=
BytesIO
(
pix
.
tobytes
(
output
=
'png'
))
# 使用 Pillow 打开图像
# 使用 Pillow 打开图像
pil_image
=
Image
.
open
(
image_file
)
image_result
=
Image
.
open
(
image_file
)
if
mode
==
"cv2"
:
image_result
=
cv2
.
cvtColor
(
np
.
asarray
(
pil_image
),
cv2
.
COLOR_RGB2BGR
)
elif
mode
==
"pillow"
:
image_result
=
pil_image
else
:
else
:
raise
ValueError
(
f
"mode:
{
mode
}
is not supported."
)
raise
ValueError
(
f
"mode:
{
mode
}
is not supported."
)
...
...
magic_pdf/model/batch_analyze.py
View file @
9ce72d78
import
time
import
time
import
cv2
import
cv2
import
numpy
as
np
import
torch
import
torch
from
loguru
import
logger
from
loguru
import
logger
from
PIL
import
Image
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.config.constants
import
MODEL_NAME
# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
# from magic_pdf.data.dataset import Dataset
# from magic_pdf.libs.clean_memory import clean_memory
# from magic_pdf.libs.config_reader import get_device
# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.sub_modules.model_utils
import
(
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
# from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE
=
1
YOLO_LAYOUT_BASE_BATCH_SIZE
=
1
MFD_BASE_BATCH_SIZE
=
1
MFD_BASE_BATCH_SIZE
=
1
...
@@ -31,7 +23,6 @@ class BatchAnalyze:
...
@@ -31,7 +23,6 @@ class BatchAnalyze:
def
__call__
(
self
,
images
:
list
)
->
list
:
def
__call__
(
self
,
images
:
list
)
->
list
:
images_layout_res
=
[]
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
layout_start_time
=
time
.
time
()
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
# layoutlmv3
...
@@ -41,36 +32,14 @@ class BatchAnalyze:
...
@@ -41,36 +32,14 @@ class BatchAnalyze:
elif
self
.
model
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
elif
self
.
model
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
# doclayout_yolo
layout_images
=
[]
layout_images
=
[]
modified_images
=
[]
for
image_index
,
image
in
enumerate
(
images
):
for
image_index
,
image
in
enumerate
(
images
):
pil_img
=
Image
.
fromarray
(
image
)
layout_images
.
append
(
image
)
# width, height = pil_img.size
# if height > width:
# input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
# new_image, useful_list = crop_img(
# input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
# )
# layout_images.append(new_image)
# modified_images.append([image_index, useful_list])
# else:
layout_images
.
append
(
pil_img
)
images_layout_res
+=
self
.
model
.
layout_model
.
batch_predict
(
images_layout_res
+=
self
.
model
.
layout_model
.
batch_predict
(
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images
,
YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images
,
YOLO_LAYOUT_BASE_BATCH_SIZE
)
)
for
image_index
,
useful_list
in
modified_images
:
for
res
in
images_layout_res
[
image_index
]:
for
i
in
range
(
len
(
res
[
'poly'
])):
if
i
%
2
==
0
:
res
[
'poly'
][
i
]
=
(
res
[
'poly'
][
i
]
-
useful_list
[
0
]
+
useful_list
[
2
]
)
else
:
res
[
'poly'
][
i
]
=
(
res
[
'poly'
][
i
]
-
useful_list
[
1
]
+
useful_list
[
3
]
)
logger
.
info
(
logger
.
info
(
f
'layout time:
{
round
(
time
.
time
()
-
layout_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
f
'layout time:
{
round
(
time
.
time
()
-
layout_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
)
...
@@ -111,7 +80,7 @@ class BatchAnalyze:
...
@@ -111,7 +80,7 @@ class BatchAnalyze:
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for
index
in
range
(
len
(
images
)):
for
index
in
range
(
len
(
images
)):
layout_res
=
images_layout_res
[
index
]
layout_res
=
images_layout_res
[
index
]
pil_img
=
Image
.
fromarray
(
images
[
index
]
)
np_array_img
=
images
[
index
]
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
(
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
(
get_res_list_from_layout_res
(
layout_res
)
get_res_list_from_layout_res
(
layout_res
)
...
@@ -121,14 +90,14 @@ class BatchAnalyze:
...
@@ -121,14 +90,14 @@ class BatchAnalyze:
# Process each area that requires OCR processing
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
new_image
,
useful_list
=
crop_img
(
res
,
pil
_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
res
,
np_array
_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
single_page_mfdetrec_res
,
useful_list
)
)
# OCR recognition
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
)
,
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
model
.
apply_ocr
:
if
self
.
model
.
apply_ocr
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
...
@@ -150,7 +119,7 @@ class BatchAnalyze:
...
@@ -150,7 +119,7 @@ class BatchAnalyze:
if
self
.
model
.
apply_table
:
if
self
.
model
.
apply_table
:
table_start
=
time
.
time
()
table_start
=
time
.
time
()
for
res
in
table_res_list
:
for
res
in
table_res_list
:
new_image
,
_
=
crop_img
(
res
,
pil
_img
)
new_image
,
_
=
crop_img
(
res
,
np_array
_img
)
single_table_start_time
=
time
.
time
()
single_table_start_time
=
time
.
time
()
html_code
=
None
html_code
=
None
if
self
.
model
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
if
self
.
model
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
...
@@ -197,83 +166,3 @@ class BatchAnalyze:
...
@@ -197,83 +166,3 @@ class BatchAnalyze:
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
return
images_layout_res
return
images_layout_res
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
9ce72d78
import
concurrent.futures
as
fut
import
multiprocessing
as
mp
import
os
import
os
import
time
import
time
import
numpy
as
np
import
torch
import
torch
os
.
environ
[
'FLAGS_npu_jit_compile'
]
=
'0'
# 关闭paddle的jit编译
os
.
environ
[
'FLAGS_npu_jit_compile'
]
=
'0'
# 关闭paddle的jit编译
os
.
environ
[
'FLAGS_use_stride_kernel'
]
=
'0'
os
.
environ
[
'FLAGS_use_stride_kernel'
]
=
'0'
os
.
environ
[
'PYTORCH_ENABLE_MPS_FALLBACK'
]
=
'1'
# 让mps可以fallback
os
.
environ
[
'PYTORCH_ENABLE_MPS_FALLBACK'
]
=
'1'
# 让mps可以fallback
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
# 关闭paddle的信号处理
import
paddle
paddle
.
disable_signal_handler
()
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.model.batch_analyze
import
BatchAnalyze
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
try
:
try
:
...
@@ -30,8 +31,8 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
...
@@ -30,8 +31,8 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir
,
get_local_models_dir
,
get_table_recog_config
)
get_table_recog_config
)
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.operators.models
import
InferenceResult
# from magic_pdf.operators.models import InferenceResult
class
ModelSingleton
:
class
ModelSingleton
:
_instance
=
None
_instance
=
None
...
@@ -72,9 +73,7 @@ def custom_model_init(
...
@@ -72,9 +73,7 @@ def custom_model_init(
formula_enable
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
table_enable
=
None
,
):
):
model
=
None
model
=
None
if
model_config
.
__model_mode__
==
'lite'
:
if
model_config
.
__model_mode__
==
'lite'
:
logger
.
warning
(
logger
.
warning
(
'The Lite mode is provided for developers to conduct testing only, and the output quality is '
'The Lite mode is provided for developers to conduct testing only, and the output quality is '
...
@@ -132,7 +131,6 @@ def custom_model_init(
...
@@ -132,7 +131,6 @@ def custom_model_init(
return
custom_model
return
custom_model
def
doc_analyze
(
def
doc_analyze
(
dataset
:
Dataset
,
dataset
:
Dataset
,
ocr
:
bool
=
False
,
ocr
:
bool
=
False
,
...
@@ -143,14 +141,112 @@ def doc_analyze(
...
@@ -143,14 +141,112 @@ def doc_analyze(
layout_model
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
table_enable
=
None
,
)
->
InferenceResult
:
):
end_page_id
=
(
end_page_id
=
(
end_page_id
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
len
(
dataset
)
-
1
else
len
(
dataset
)
-
1
)
)
MIN_BATCH_INFERENCE_SIZE
=
int
(
os
.
environ
.
get
(
'MINERU_MIN_BATCH_INFERENCE_SIZE'
,
100
))
images
=
[]
page_wh_list
=
[]
for
index
in
range
(
len
(
dataset
)):
if
start_page_id
<=
index
<=
end_page_id
:
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
len
(
images
)
>=
MIN_BATCH_INFERENCE_SIZE
:
batch_size
=
MIN_BATCH_INFERENCE_SIZE
batch_images
=
[
images
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
),
batch_size
)]
else
:
batch_images
=
[
images
]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
results
.
extend
(
result
)
model_json
=
[]
for
index
in
range
(
len
(
dataset
)):
if
start_page_id
<=
index
<=
end_page_id
:
result
=
results
.
pop
(
0
)
page_width
,
page_height
=
page_wh_list
.
pop
(
0
)
else
:
result
=
[]
page_height
=
0
page_width
=
0
page_info
=
{
'page_no'
:
index
,
'width'
:
page_width
,
'height'
:
page_height
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
from
magic_pdf.operators.models
import
InferenceResult
return
InferenceResult
(
model_json
,
dataset
)
def
batch_doc_analyze
(
datasets
:
list
[
Dataset
],
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
):
MIN_BATCH_INFERENCE_SIZE
=
int
(
os
.
environ
.
get
(
'MINERU_MIN_BATCH_INFERENCE_SIZE'
,
100
))
images
=
[]
page_wh_list
=
[]
for
dataset
in
datasets
:
for
index
in
range
(
len
(
dataset
)):
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
len
(
images
)
>=
MIN_BATCH_INFERENCE_SIZE
:
batch_size
=
MIN_BATCH_INFERENCE_SIZE
batch_images
=
[
images
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images
),
batch_size
)]
else
:
batch_images
=
[
images
]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
results
.
extend
(
result
)
infer_results
=
[]
from
magic_pdf.operators.models
import
InferenceResult
for
index
in
range
(
len
(
datasets
)):
dataset
=
datasets
[
index
]
model_json
=
[]
for
i
in
range
(
len
(
dataset
)):
result
=
results
.
pop
(
0
)
page_width
,
page_height
=
page_wh_list
.
pop
(
0
)
page_info
=
{
'page_no'
:
i
,
'width'
:
page_width
,
'height'
:
page_height
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
infer_results
.
append
(
InferenceResult
(
model_json
,
dataset
))
return
infer_results
def
may_batch_image_analyze
(
images
:
list
[
np
.
ndarray
],
idx
:
int
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
# 关闭paddle的信号处理
import
paddle
paddle
.
disable_signal_handler
()
from
magic_pdf.model.batch_analyze
import
BatchAnalyze
model_manager
=
ModelSingleton
()
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
...
@@ -160,33 +256,32 @@ def doc_analyze(
...
@@ -160,33 +256,32 @@ def doc_analyze(
batch_ratio
=
1
batch_ratio
=
1
device
=
get_device
()
device
=
get_device
()
npu_support
=
False
if
str
(
device
).
startswith
(
'npu'
):
if
str
(
device
).
startswith
(
"npu"
):
import
torch_npu
import
torch_npu
if
torch_npu
.
npu
.
is_available
():
if
torch_npu
.
npu
.
is_available
():
npu_support
=
True
torch
.
npu
.
set_compile_mode
(
jit_compile
=
False
)
torch
.
npu
.
set_compile_mode
(
jit_compile
=
False
)
if
torch
.
cuda
.
is_available
()
and
device
!=
'cpu'
or
npu_support
:
if
str
(
device
).
startswith
(
'npu'
)
or
str
(
device
).
startswith
(
'cuda'
)
:
gpu_memory
=
int
(
os
.
getenv
(
"
VIRTUAL_VRAM_SIZE
"
,
round
(
get_vram
(
device
))))
gpu_memory
=
int
(
os
.
getenv
(
'
VIRTUAL_VRAM_SIZE
'
,
round
(
get_vram
(
device
))))
if
gpu_memory
is
not
None
and
gpu_memory
>=
8
:
if
gpu_memory
is
not
None
:
if
gpu_memory
>=
20
:
if
gpu_memory
>=
20
:
batch_ratio
=
16
batch_ratio
=
16
elif
gpu_memory
>=
15
:
elif
gpu_memory
>=
15
:
batch_ratio
=
8
batch_ratio
=
8
elif
gpu_memory
>=
10
:
elif
gpu_memory
>=
10
:
batch_ratio
=
4
batch_ratio
=
4
el
se
:
el
if
gpu_memory
>=
7
:
batch_ratio
=
2
batch_ratio
=
2
else
:
batch_ratio
=
1
logger
.
info
(
f
'gpu_memory:
{
gpu_memory
}
GB, batch_ratio:
{
batch_ratio
}
'
)
logger
.
info
(
f
'gpu_memory:
{
gpu_memory
}
GB, batch_ratio:
{
batch_ratio
}
'
)
batch_analyze
=
True
batch_analyze
=
True
elif
str
(
device
).
startswith
(
'mps'
):
model_json
=
[]
batch_analyze
=
True
doc_analyze_start
=
time
.
time
()
doc_analyze_start
=
time
.
time
()
if
batch_analyze
:
if
batch_analyze
:
# batch analyze
"""
# batch analyze
images = []
images = []
page_wh_list = []
page_wh_list = []
for index in range(len(dataset)):
for index in range(len(dataset)):
...
@@ -195,9 +290,10 @@ def doc_analyze(
...
@@ -195,9 +290,10 @@ def doc_analyze(
img_dict = page_data.get_image()
img_dict = page_data.get_image()
images.append(img_dict['img'])
images.append(img_dict['img'])
page_wh_list.append((img_dict['width'], img_dict['height']))
page_wh_list.append((img_dict['width'], img_dict['height']))
"""
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
analyze_
result
=
batch_model
(
images
)
result
s
=
batch_model
(
images
)
"""
for index in range(len(dataset)):
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
result = analyze_result.pop(0)
...
@@ -210,10 +306,10 @@ def doc_analyze(
...
@@ -210,10 +306,10 @@ def doc_analyze(
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
model_json.append(page_dict)
"""
else
:
else
:
# single analyze
# single analyze
"""
for index in range(len(dataset)):
for index in range(len(dataset)):
page_data = dataset.get_page(index)
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img_dict = page_data.get_image()
...
@@ -230,6 +326,13 @@ def doc_analyze(
...
@@ -230,6 +326,13 @@ def doc_analyze(
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_info = {'page_no': index, 'width': page_width, 'height': page_height}
page_dict = {'layout_dets': result, 'page_info': page_info}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
model_json.append(page_dict)
"""
results
=
[]
for
img_idx
,
img
in
enumerate
(
images
):
inference_start
=
time
.
time
()
result
=
custom_model
(
img
)
logger
.
info
(
f
'-----image index :
{
img_idx
}
, image inference total time:
{
round
(
time
.
time
()
-
inference_start
,
2
)
}
-----'
)
results
.
append
(
result
)
gc_start
=
time
.
time
()
gc_start
=
time
.
time
()
clean_memory
(
get_device
())
clean_memory
(
get_device
())
...
@@ -237,10 +340,9 @@ def doc_analyze(
...
@@ -237,10 +340,9 @@ def doc_analyze(
logger
.
info
(
f
'gc time:
{
gc_time
}
'
)
logger
.
info
(
f
'gc time:
{
gc_time
}
'
)
doc_analyze_time
=
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
doc_analyze_time
=
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
doc_analyze_speed
=
round
(
(
en
d_page_id
+
1
-
start_page_id
)
/
doc_analyze_time
,
2
)
doc_analyze_speed
=
round
(
l
en
(
images
)
/
doc_analyze_time
,
2
)
logger
.
info
(
logger
.
info
(
f
'doc analyze time:
{
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
}
,'
f
'doc analyze time:
{
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
}
,'
f
' speed:
{
doc_analyze_speed
}
pages/second'
f
' speed:
{
doc_analyze_speed
}
pages/second'
)
)
return
(
idx
,
results
)
return
InferenceResult
(
model_json
,
dataset
)
magic_pdf/model/pdf_extract_kit.py
View file @
9ce72d78
...
@@ -3,11 +3,9 @@ import os
...
@@ -3,11 +3,9 @@ import os
import
time
import
time
import
cv2
import
cv2
import
numpy
as
np
import
torch
import
torch
import
yaml
import
yaml
from
loguru
import
logger
from
loguru
import
logger
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
...
@@ -120,7 +118,7 @@ class CustomPEKModel:
...
@@ -120,7 +118,7 @@ class CustomPEKModel:
atom_model_name
=
AtomicModel
.
MFR
,
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_cfg_path
=
mfr_cfg_path
,
mfr_cfg_path
=
mfr_cfg_path
,
device
=
'cpu'
if
str
(
self
.
device
).
startswith
(
"mps"
)
else
self
.
device
,
device
=
self
.
device
,
)
)
# 初始化layout模型
# 初始化layout模型
...
@@ -174,11 +172,6 @@ class CustomPEKModel:
...
@@ -174,11 +172,6 @@ class CustomPEKModel:
logger
.
info
(
'DocAnalysis init done!'
)
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
image
):
def
__call__
(
self
,
image
):
pil_img
=
Image
.
fromarray
(
image
)
width
,
height
=
pil_img
.
size
# logger.info(f'width: {width}, height: {height}')
# layout检测
# layout检测
layout_start
=
time
.
time
()
layout_start
=
time
.
time
()
layout_res
=
[]
layout_res
=
[]
...
@@ -186,24 +179,6 @@ class CustomPEKModel:
...
@@ -186,24 +179,6 @@ class CustomPEKModel:
# layoutlmv3
# layoutlmv3
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
...
@@ -234,11 +209,11 @@ class CustomPEKModel:
...
@@ -234,11 +209,11 @@ class CustomPEKModel:
ocr_start
=
time
.
time
()
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
new_image
,
useful_list
=
crop_img
(
res
,
image
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
)
,
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
apply_ocr
:
if
self
.
apply_ocr
:
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
...
@@ -260,7 +235,7 @@ class CustomPEKModel:
...
@@ -260,7 +235,7 @@ class CustomPEKModel:
if
self
.
apply_table
:
if
self
.
apply_table
:
table_start
=
time
.
time
()
table_start
=
time
.
time
()
for
res
in
table_res_list
:
for
res
in
table_res_list
:
new_image
,
_
=
crop_img
(
res
,
pil_img
)
new_image
,
_
=
crop_img
(
res
,
image
)
single_table_start_time
=
time
.
time
()
single_table_start_time
=
time
.
time
()
html_code
=
None
html_code
=
None
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
...
...
magic_pdf/model/sub_modules/language_detection/utils.py
View file @
9ce72d78
...
@@ -3,8 +3,6 @@ import os
...
@@ -3,8 +3,6 @@ import os
from
pathlib
import
Path
from
pathlib
import
Path
import
yaml
import
yaml
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.config.constants
import
MODEL_NAME
...
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
...
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
)
)
text_images
=
[]
text_images
=
[]
for
simple_image
in
simple_images
:
for
simple_image
in
simple_images
:
image
=
Image
.
fromarray
(
simple_image
[
'img'
]
)
image
=
simple_image
[
'img'
]
layout_res
=
temp_layout_model
.
predict
(
image
)
layout_res
=
temp_layout_model
.
predict
(
image
)
# 给textblock截图
# 给textblock截图
for
res
in
layout_res
:
for
res
in
layout_res
:
...
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
...
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
# 初步清洗(宽和高都小于100)
# 初步清洗(宽和高都小于100)
if
x2
-
x1
<
100
and
y2
-
y1
<
100
:
if
x2
-
x1
<
100
and
y2
-
y1
<
100
:
continue
continue
text_images
.
append
(
image
.
crop
((
x1
,
y1
,
x2
,
y2
))
)
text_images
.
append
(
image
[
y1
:
y2
,
x1
:
x2
]
)
return
text_images
return
text_images
...
...
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
View file @
9ce72d78
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
import
time
import
time
from
collections
import
Counter
from
collections
import
Counter
from
uuid
import
uuid4
from
uuid
import
uuid4
import
cv2
import
numpy
as
np
import
torch
import
torch
from
PIL
import
Image
from
loguru
import
logger
from
loguru
import
logger
from
ultralytics
import
YOLO
from
ultralytics
import
YOLO
...
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
...
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
if
result_images
is
None
:
if
result_images
is
None
:
result_images
=
[]
result_images
=
[]
width
,
height
=
image
.
s
ize
height
,
width
=
image
.
s
hape
[:
2
]
long_side
=
max
(
width
,
height
)
# 获取较长边长度
long_side
=
max
(
width
,
height
)
# 获取较长边长度
if
long_side
<=
400
:
if
long_side
<=
400
:
...
@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
...
@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if
x
+
new_long_side
>
width
:
if
x
+
new_long_side
>
width
:
continue
continue
box
=
(
x
,
0
,
x
+
new_long_side
,
height
)
sub_image
=
image
[
0
:
height
,
x
:
x
+
new_long_side
]
sub_image
=
image
.
crop
(
box
)
sub_images
.
append
(
sub_image
)
sub_images
.
append
(
sub_image
)
else
:
# 如果高度是较长边
else
:
# 如果高度是较长边
for
y
in
range
(
0
,
height
,
new_long_side
):
for
y
in
range
(
0
,
height
,
new_long_side
):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if
y
+
new_long_side
>
height
:
if
y
+
new_long_side
>
height
:
continue
continue
box
=
(
0
,
y
,
width
,
y
+
new_long_side
)
sub_image
=
image
[
y
:
y
+
new_long_side
,
0
:
width
]
sub_image
=
image
.
crop
(
box
)
sub_images
.
append
(
sub_image
)
sub_images
.
append
(
sub_image
)
for
sub_image
in
sub_images
:
for
sub_image
in
sub_images
:
...
@@ -64,24 +62,32 @@ def split_images(image, result_images=None):
...
@@ -64,24 +62,32 @@ def split_images(image, result_images=None):
def
resize_images_to_224
(
image
):
def
resize_images_to_224
(
image
):
"""
"""
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
Works directly with NumPy arrays.
"""
"""
try
:
try
:
width
,
height
=
image
.
size
height
,
width
=
image
.
shape
[:
2
]
if
width
<
224
or
height
<
224
:
if
width
<
224
or
height
<
224
:
new_image
=
Image
.
new
(
'RGB'
,
(
224
,
224
),
(
0
,
0
,
0
))
# Create black background
paste_x
=
(
224
-
width
)
//
2
new_image
=
np
.
zeros
((
224
,
224
,
3
),
dtype
=
np
.
uint8
)
paste_y
=
(
224
-
height
)
//
2
# Calculate paste position (ensure they're not negative)
new_image
.
paste
(
image
,
(
paste_x
,
paste_y
))
paste_x
=
max
(
0
,
(
224
-
width
)
//
2
)
paste_y
=
max
(
0
,
(
224
-
height
)
//
2
)
# Make sure we don't exceed the boundaries of new_image
paste_width
=
min
(
width
,
224
)
paste_height
=
min
(
height
,
224
)
# Paste original image onto black background
new_image
[
paste_y
:
paste_y
+
paste_height
,
paste_x
:
paste_x
+
paste_width
]
=
image
[:
paste_height
,
:
paste_width
]
image
=
new_image
image
=
new_image
else
:
else
:
image
=
image
.
resize
((
224
,
224
),
Image
.
Resampling
.
LANCZOS
)
# Resize using cv2
image
=
cv2
.
resize
(
image
,
(
224
,
224
),
interpolation
=
cv2
.
INTER_LANCZOS4
)
# uuid = str(uuid4())
# image.save(f"/tmp/{uuid}.jpg")
return
image
return
image
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
exception
(
e
)
logger
.
exception
(
f
"Error in resize_images_to_224:
{
e
}
"
)
return
None
class
YOLOv11LangDetModel
(
object
):
class
YOLOv11LangDetModel
(
object
):
...
@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
...
@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
def
do_detect
(
self
,
images
:
list
):
def
do_detect
(
self
,
images
:
list
):
all_images
=
[]
all_images
=
[]
for
image
in
images
:
for
image
in
images
:
width
,
height
=
image
.
size
height
,
width
=
image
.
shape
[:
2
]
# logger.info(f"image size: {width} x {height}")
if
width
<
100
and
height
<
100
:
if
width
<
100
and
height
<
100
:
continue
continue
temp_images
=
split_images
(
image
)
temp_images
=
split_images
(
image
)
...
...
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
View file @
9ce72d78
import
argparse
import
os
import
re
import
torch
import
torch
import
unimernet.tasks
as
tasks
from
PIL
import
Image
from
torch.utils.data
import
DataLoader
,
Dataset
from
torch.utils.data
import
DataLoader
,
Dataset
from
torchvision
import
transforms
from
unimernet.common.config
import
Config
from
unimernet.processors
import
load_processor
class
MathDataset
(
Dataset
):
class
MathDataset
(
Dataset
):
...
@@ -20,55 +11,25 @@ class MathDataset(Dataset):
...
@@ -20,55 +11,25 @@ class MathDataset(Dataset):
return
len
(
self
.
image_paths
)
return
len
(
self
.
image_paths
)
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
# if not pil image, then convert to pil image
if
isinstance
(
self
.
image_paths
[
idx
],
str
):
raw_image
=
Image
.
open
(
self
.
image_paths
[
idx
])
else
:
raw_image
=
self
.
image_paths
[
idx
]
raw_image
=
self
.
image_paths
[
idx
]
if
self
.
transform
:
if
self
.
transform
:
image
=
self
.
transform
(
raw_image
)
image
=
self
.
transform
(
raw_image
)
return
image
return
image
def
latex_rm_whitespace
(
s
:
str
):
"""Remove unnecessary whitespace from LaTeX code."""
text_reg
=
r
"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
letter
=
"[a-zA-Z]"
noletter
=
"[\W_^\d]"
names
=
[
x
[
0
].
replace
(
" "
,
""
)
for
x
in
re
.
findall
(
text_reg
,
s
)]
s
=
re
.
sub
(
text_reg
,
lambda
match
:
str
(
names
.
pop
(
0
)),
s
)
news
=
s
while
True
:
s
=
news
news
=
re
.
sub
(
r
"(?!\\ )(%s)\s+?(%s)"
%
(
noletter
,
noletter
),
r
"\1\2"
,
s
)
news
=
re
.
sub
(
r
"(?!\\ )(%s)\s+?(%s)"
%
(
noletter
,
letter
),
r
"\1\2"
,
news
)
news
=
re
.
sub
(
r
"(%s)\s+?(%s)"
%
(
letter
,
noletter
),
r
"\1\2"
,
news
)
if
news
==
s
:
break
return
s
class
UnimernetModel
(
object
):
class
UnimernetModel
(
object
):
def
__init__
(
self
,
weight_dir
,
cfg_path
,
_device_
=
"cpu"
):
def
__init__
(
self
,
weight_dir
,
cfg_path
,
_device_
=
"cpu"
):
args
=
argparse
.
Namespace
(
cfg_path
=
cfg_path
,
options
=
None
)
from
.unimernet_hf
import
UnimernetModel
cfg
=
Config
(
args
)
if
_device_
.
startswith
(
"mps"
):
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.pth"
)
self
.
model
=
UnimernetModel
.
from_pretrained
(
weight_dir
,
attn_implementation
=
"eager"
)
cfg
.
config
.
model
.
model_config
.
model_name
=
weight_dir
else
:
cfg
.
config
.
model
.
tokenizer_config
.
path
=
weight_dir
self
.
model
=
UnimernetModel
.
from_pretrained
(
weight_dir
)
task
=
tasks
.
setup_task
(
cfg
)
self
.
model
=
task
.
build_model
(
cfg
)
self
.
device
=
_device_
self
.
device
=
_device_
self
.
model
.
to
(
_device_
)
self
.
model
.
to
(
_device_
)
if
not
_device_
.
startswith
(
"cpu"
):
self
.
model
=
self
.
model
.
to
(
dtype
=
torch
.
float16
)
self
.
model
.
eval
()
self
.
model
.
eval
()
vis_processor
=
load_processor
(
"formula_image_eval"
,
cfg
.
config
.
datasets
.
formula_rec_eval
.
vis_processor
.
eval
,
)
self
.
mfr_transform
=
transforms
.
Compose
(
[
vis_processor
,
]
)
def
predict
(
self
,
mfd_res
,
image
):
def
predict
(
self
,
mfd_res
,
image
):
formula_list
=
[]
formula_list
=
[]
...
@@ -84,62 +45,22 @@ class UnimernetModel(object):
...
@@ -84,62 +45,22 @@ class UnimernetModel(object):
"latex"
:
""
,
"latex"
:
""
,
}
}
formula_list
.
append
(
new_item
)
formula_list
.
append
(
new_item
)
pil_img
=
Image
.
fromarray
(
image
)
bbox_img
=
image
[
ymin
:
ymax
,
xmin
:
xmax
]
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
mf_image_list
.
append
(
bbox_img
)
mf_image_list
.
append
(
bbox_img
)
dataset
=
MathDataset
(
mf_image_list
,
transform
=
self
.
m
fr_
transform
)
dataset
=
MathDataset
(
mf_image_list
,
transform
=
self
.
m
odel
.
transform
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
32
,
num_workers
=
0
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
32
,
num_workers
=
0
)
mfr_res
=
[]
mfr_res
=
[]
for
mf_img
in
dataloader
:
for
mf_img
in
dataloader
:
mf_img
=
mf_img
.
to
(
dtype
=
self
.
model
.
dtype
)
mf_img
=
mf_img
.
to
(
self
.
device
)
mf_img
=
mf_img
.
to
(
self
.
device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
mfr_res
.
extend
(
output
[
"
pr
ed_str"
])
mfr_res
.
extend
(
output
[
"
fix
ed_str"
])
for
res
,
latex
in
zip
(
formula_list
,
mfr_res
):
for
res
,
latex
in
zip
(
formula_list
,
mfr_res
):
res
[
"latex"
]
=
latex
_rm_whitespace
(
latex
)
res
[
"latex"
]
=
latex
return
formula_list
return
formula_list
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def
batch_predict
(
self
,
images_mfd_res
:
list
,
images
:
list
,
batch_size
:
int
=
64
)
->
list
:
def
batch_predict
(
self
,
images_mfd_res
:
list
,
images
:
list
,
batch_size
:
int
=
64
)
->
list
:
images_formula_list
=
[]
images_formula_list
=
[]
mf_image_list
=
[]
mf_image_list
=
[]
...
@@ -149,7 +70,7 @@ class UnimernetModel(object):
...
@@ -149,7 +70,7 @@ class UnimernetModel(object):
# Collect images with their original indices
# Collect images with their original indices
for
image_index
in
range
(
len
(
images_mfd_res
)):
for
image_index
in
range
(
len
(
images_mfd_res
)):
mfd_res
=
images_mfd_res
[
image_index
]
mfd_res
=
images_mfd_res
[
image_index
]
pil_img
=
Image
.
fromarray
(
images
[
image_index
]
)
np_array_image
=
images
[
image_index
]
formula_list
=
[]
formula_list
=
[]
for
idx
,
(
xyxy
,
conf
,
cla
)
in
enumerate
(
zip
(
for
idx
,
(
xyxy
,
conf
,
cla
)
in
enumerate
(
zip
(
...
@@ -163,7 +84,7 @@ class UnimernetModel(object):
...
@@ -163,7 +84,7 @@ class UnimernetModel(object):
"latex"
:
""
,
"latex"
:
""
,
}
}
formula_list
.
append
(
new_item
)
formula_list
.
append
(
new_item
)
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
y
max
))
bbox_img
=
np_array_image
[
ymin
:
ymax
,
xmin
:
x
max
]
area
=
(
xmax
-
xmin
)
*
(
ymax
-
ymin
)
area
=
(
xmax
-
xmin
)
*
(
ymax
-
ymin
)
curr_idx
=
len
(
mf_image_list
)
curr_idx
=
len
(
mf_image_list
)
...
@@ -182,22 +103,23 @@ class UnimernetModel(object):
...
@@ -182,22 +103,23 @@ class UnimernetModel(object):
index_mapping
=
{
new_idx
:
old_idx
for
new_idx
,
old_idx
in
enumerate
(
sorted_indices
)}
index_mapping
=
{
new_idx
:
old_idx
for
new_idx
,
old_idx
in
enumerate
(
sorted_indices
)}
# Create dataset with sorted images
# Create dataset with sorted images
dataset
=
MathDataset
(
sorted_images
,
transform
=
self
.
m
fr_
transform
)
dataset
=
MathDataset
(
sorted_images
,
transform
=
self
.
m
odel
.
transform
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
# Process batches and store results
# Process batches and store results
mfr_res
=
[]
mfr_res
=
[]
for
mf_img
in
dataloader
:
for
mf_img
in
dataloader
:
mf_img
=
mf_img
.
to
(
dtype
=
self
.
model
.
dtype
)
mf_img
=
mf_img
.
to
(
self
.
device
)
mf_img
=
mf_img
.
to
(
self
.
device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
mfr_res
.
extend
(
output
[
"
pr
ed_str"
])
mfr_res
.
extend
(
output
[
"
fix
ed_str"
])
# Restore original order
# Restore original order
unsorted_results
=
[
""
]
*
len
(
mfr_res
)
unsorted_results
=
[
""
]
*
len
(
mfr_res
)
for
new_idx
,
latex
in
enumerate
(
mfr_res
):
for
new_idx
,
latex
in
enumerate
(
mfr_res
):
original_idx
=
index_mapping
[
new_idx
]
original_idx
=
index_mapping
[
new_idx
]
unsorted_results
[
original_idx
]
=
latex
_rm_whitespace
(
latex
)
unsorted_results
[
original_idx
]
=
latex
# Fill results back
# Fill results back
for
res
,
latex
in
zip
(
backfill_list
,
unsorted_results
):
for
res
,
latex
in
zip
(
backfill_list
,
unsorted_results
):
...
...
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
0 → 100644
View file @
9ce72d78
from
.unimer_swin
import
UnimerSwinConfig
,
UnimerSwinModel
,
UnimerSwinImageProcessor
from
.unimer_mbart
import
UnimerMBartConfig
,
UnimerMBartModel
,
UnimerMBartForCausalLM
from
.modeling_unimernet
import
UnimernetModel
__all__
=
[
"UnimerSwinConfig"
,
"UnimerSwinModel"
,
"UnimerSwinImageProcessor"
,
"UnimerMBartConfig"
,
"UnimerMBartModel"
,
"UnimerMBartForCausalLM"
,
"UnimernetModel"
,
]
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
0 → 100644
View file @
9ce72d78
import
os
import
re
import
warnings
from
typing
import
Optional
import
torch
from
ftfy
import
fix_text
from
transformers
import
AutoConfig
,
AutoModel
,
AutoModelForCausalLM
,
AutoTokenizer
,
PretrainedConfig
,
PreTrainedModel
from
transformers
import
VisionEncoderDecoderConfig
,
VisionEncoderDecoderModel
from
transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder
import
logger
as
base_model_logger
from
.unimer_swin
import
UnimerSwinConfig
,
UnimerSwinModel
,
UnimerSwinImageProcessor
from
.unimer_mbart
import
UnimerMBartConfig
,
UnimerMBartForCausalLM
# Register the custom Unimer configs/models with the transformers Auto* factories
# so that AutoConfig/AutoModel/AutoModelForCausalLM can resolve the
# "unimer-swin" / "unimer-mbart" model_type strings to these classes.
AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
# TODO: rewrite tokenizer
class TokenizerWrapper:
    """Thin adapter over a Hugging Face tokenizer.

    Exposes only the operations UnimernetModel needs (tokenize, batch decode,
    token-level detokenize) and mirrors the special-token ids as attributes.
    """

    def __init__(self, tokenizer):
        # `tokenizer` is expected to be an AutoTokenizer instance
        # (see UnimernetModel.__init__) — TODO confirm for other callers.
        self.tokenizer = tokenizer
        # Mirror the special-token ids for direct attribute access.
        self.pad_token_id = self.tokenizer.pad_token_id
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

    def __len__(self):
        # Vocabulary size of the wrapped tokenizer.
        return len(self.tokenizer)

    def tokenize(self, text, **kwargs):
        """Encode `text` into padded/truncated PyTorch tensors."""
        return self.tokenizer(
            text,
            return_token_type_ids=False,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            **kwargs,
        )

    def token2str(self, tokens) -> list:
        """Batch-decode token ids to strings, then repair text encoding via ftfy."""
        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
        generated_text = [fix_text(text) for text in generated_text]
        return generated_text

    def detokenize(self, tokens):
        """Convert each sequence of ids to cleaned token strings.

        Replaces the BPE space marker 'Ġ' with a real space, strips whitespace,
        and drops bos/eos/pad tokens. Iterates in reverse so `del` does not
        shift indices that are still to be visited.
        """
        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
        for b in range(len(toks)):
            for i in reversed(range(len(toks[b]))):
                if toks[b][i] is None:
                    toks[b][i] = ''
                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
                if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
                    del toks[b][i]
        return toks
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code.

    First collapses spaces inside ``\\operatorname/\\mathrm/\\text/\\mathbf{...}``
    groups, then repeatedly deletes whitespace between character classes
    (non-letter/non-letter, non-letter/letter, letter/non-letter) until a
    fixed point is reached. Whitespace between two letters is preserved, as
    is an explicit escaped space ``\\ `` (guarded by the lookahead).

    Args:
        s: LaTeX source string.

    Returns:
        The string with redundant whitespace removed.
    """
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    noletter = r'[\W_^\d]'
    # Rewrite each matched \cmd {...} group with its internal spaces removed.
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
    # Hoist the loop-invariant patterns out of the fixed-point loop:
    # compile once instead of re-building the pattern strings each iteration.
    noletter_noletter = re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter))
    noletter_letter = re.compile(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter))
    letter_noletter = re.compile(r'(%s)\s+?(%s)' % (letter, noletter))
    news = s
    while True:
        s = news
        news = noletter_noletter.sub(r'\1\2', s)
        news = noletter_letter.sub(r'\1\2', news)
        news = letter_noletter.sub(r'\1\2', news)
        if news == s:
            # Fixed point reached: no further whitespace can be removed.
            break
    return s
class UnimernetModel(VisionEncoderDecoderModel):
    """UniMERNet formula-recognition model: UnimerSwin encoder + UnimerMBart decoder.

    Built on transformers' VisionEncoderDecoderModel; adds an image transform,
    a wrapped tokenizer, and checkpoint loading from a raw .pth state dict.
    """

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        # VisionEncoderDecoderModel's checking log has bug, disable for temp.
        # (finally ensures the logger is re-enabled even if __init__ raises)
        base_model_logger.disabled = True
        try:
            super().__init__(config, encoder, decoder)
        finally:
            base_model_logger.disabled = False

        # The tokenizer is loaded from the same path the config came from,
        # so _name_or_path must be present on the config.
        if not config or not hasattr(config, "_name_or_path"):
            raise RuntimeError("config._name_or_path is required by UnimernetModel.")

        model_path = config._name_or_path
        self.transform = UnimerSwinImageProcessor()
        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
        self._post_check()

    def _post_check(self):
        """Sanity-check tokenizer vs. decoder config; sync model_max_length.

        Warns and overwrites tokenizer.model_max_length when it disagrees with
        decoder.max_position_embeddings; asserts vocab size and special-token
        ids are consistent between config and tokenizer.
        """
        tokenizer = self.tokenizer

        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
            warnings.warn(
                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings},"
                + f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set"
                + f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}."
            )
            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings

        assert self.config.decoder.vocab_size == len(tokenizer)
        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
        assert self.config.pad_token_id == tokenizer.pad_token_id

    @classmethod
    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
        """Build a UnimernetModel from a directory containing config + .pth weights.

        Args:
            model_path: Directory with the VisionEncoderDecoderConfig and weights.
            model_filename: Checkpoint file name inside `model_path`.
            state_dict_strip_prefix: Prefix stripped from checkpoint keys so they
                match this model's parameter names.

        Raises:
            RuntimeError: If the state dict is empty or has missing keys.
        """
        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
        config._name_or_path = model_path
        # Re-wrap the generic sub-configs as the Unimer-specific config classes.
        config.encoder = UnimerSwinConfig(**vars(config.encoder))
        config.decoder = UnimerMBartConfig(**vars(config.decoder))

        encoder = UnimerSwinModel(config.encoder)
        decoder = UnimerMBartForCausalLM(config.decoder)
        model = cls(config, encoder, decoder)

        # load model weights
        model_file_path = os.path.join(model_path, model_filename)
        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
        # Checkpoints may nest the weights under a "model" key.
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        if not state_dict:
            raise RuntimeError("state_dict is empty.")
        if state_dict_strip_prefix:
            state_dict = {
                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
                for k, v in state_dict.items()
            }
        # strict=False: unexpected keys only warn; missing keys are fatal below.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if len(unexpected_keys) > 0:
            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
        return model

    def forward_bak(self, samples):
        """Legacy training forward: teacher-forced loss over tokenized text.

        `samples` must contain "image" (pixel tensor) and "text_input"
        (list of target strings). Returns {"loss": ...}.
        """
        pixel_values, text = samples["image"], samples["text_input"]

        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]

        # Grayscale input: replicate the single channel to 3 channels.
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        labels = decoder_input_ids * 1
        # -100 is the ignore_index convention for padded label positions.
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        # Shift: inputs drop the last token, labels drop the first.
        loss = self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids[:, :-1],
            decoder_attention_mask=decoder_attention_mask[:, :-1],
            labels=labels[:, 1:],
        ).loss
        return {"loss": loss}

    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
        """Run autoregressive decoding on `samples["image"]`.

        Returns a dict with raw ids ("pred_ids"), cleaned token lists
        ("pred_tokens"), decoded strings ("pred_str"), and whitespace-normalized
        strings ("fixed_str").
        """
        pixel_values = samples["image"]
        # Grayscale input: replicate the single channel to 3 channels.
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        # Sampling kwargs are only forwarded when sampling is enabled.
        kwargs = {}
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p

        outputs = super().generate(
            pixel_values=pixel_values,
            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
            do_sample=do_sample,
            **kwargs,
        )
        # Drop the decoder start token before decoding.
        outputs = outputs[:, 1:].cpu().numpy()
        pred_tokens = self.tokenizer.detokenize(outputs)
        pred_str = self.tokenizer.token2str(outputs)
        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
0 → 100644
View file @
9ce72d78
from
.configuration_unimer_mbart
import
UnimerMBartConfig
from
.modeling_unimer_mbart
import
UnimerMBartModel
,
UnimerMBartForCausalLM
__all__
=
[
"UnimerMBartConfig"
,
"UnimerMBartModel"
,
"UnimerMBartForCausalLM"
,
]
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
0 → 100644
View file @
9ce72d78
# coding=utf-8
# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UnimerMBART model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class UnimerMBartConfig(PretrainedConfig):
    r"""
    Configuration class for the UnimerMBart decoder (an MBART variant).

    Instantiating a configuration with the defaults will yield a similar
    configuration to that of the MBART
    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the MBART model.
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        qk_squeeze (`int`, *optional*, defaults to 2):
            Squeeze ratio for query/key's output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
            Squeeze Attention maps the query and key to a lower-dimensional space without excessive loss of information,
            thereby accelerating the computation of attention.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by diving by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models)
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.
    """

    model_type = "unimer-mbart"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Map generic transformers attribute names onto this config's fields.
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        qk_squeeze=2,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.qk_squeeze = qk_squeeze
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        # Token ids and encoder-decoder flags are handled by PretrainedConfig.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py
0 → 100644
View file @
9ce72d78
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment