Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
41d96cd8
Unverified
Commit
41d96cd8
authored
Apr 03, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Apr 03, 2025
Browse files
Merge pull request #2065 from opendatalab/release-1.3.0
Release 1.3.0
parents
c3d43e52
dd96663c
Changes
126
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3292 additions
and
501 deletions
+3292
-501
magic_pdf/data/dataset.py
magic_pdf/data/dataset.py
+44
-24
magic_pdf/data/utils.py
magic_pdf/data/utils.py
+108
-9
magic_pdf/dict2md/ocr_mkcontent.py
magic_pdf/dict2md/ocr_mkcontent.py
+4
-3
magic_pdf/libs/pdf_image_tools.py
magic_pdf/libs/pdf_image_tools.py
+11
-6
magic_pdf/libs/performance_stats.py
magic_pdf/libs/performance_stats.py
+12
-1
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+175
-201
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+137
-92
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+5
-38
magic_pdf/model/sub_modules/language_detection/utils.py
magic_pdf/model/sub_modules/language_detection/utils.py
+2
-4
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
...f/model/sub_modules/language_detection/yolov11/YOLOv11.py
+24
-19
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
.../model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
+3
-1
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
+3
-1
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
+31
-102
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
.../model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
+13
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
..._modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
+189
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
...dules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
+8
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
...t/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
+163
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py
...mernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py
+2351
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py
...et/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py
+0
-0
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py
...odules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py
+9
-0
No files found.
magic_pdf/data/dataset.py
View file @
41d96cd8
...
@@ -97,10 +97,10 @@ class Dataset(ABC):
...
@@ -97,10 +97,10 @@ class Dataset(ABC):
@
abstractmethod
@
abstractmethod
def
dump_to_file
(
self
,
file_path
:
str
):
def
dump_to_file
(
self
,
file_path
:
str
):
"""Dump the file
"""Dump the file
.
Args:
Args:
file_path (str): the file path
file_path (str): the file path
"""
"""
pass
pass
...
@@ -119,7 +119,7 @@ class Dataset(ABC):
...
@@ -119,7 +119,7 @@ class Dataset(ABC):
@
abstractmethod
@
abstractmethod
def
classify
(
self
)
->
SupportedPdfParseMethod
:
def
classify
(
self
)
->
SupportedPdfParseMethod
:
"""classify the dataset
"""classify the dataset
.
Returns:
Returns:
SupportedPdfParseMethod: _description_
SupportedPdfParseMethod: _description_
...
@@ -128,8 +128,7 @@ class Dataset(ABC):
...
@@ -128,8 +128,7 @@ class Dataset(ABC):
@
abstractmethod
@
abstractmethod
def
clone
(
self
):
def
clone
(
self
):
"""clone this dataset
"""clone this dataset."""
"""
pass
pass
...
@@ -144,16 +143,19 @@ class PymuDocDataset(Dataset):
...
@@ -144,16 +143,19 @@ class PymuDocDataset(Dataset):
self
.
_records
=
[
Doc
(
v
)
for
v
in
self
.
_raw_fitz
]
self
.
_records
=
[
Doc
(
v
)
for
v
in
self
.
_raw_fitz
]
self
.
_data_bits
=
bits
self
.
_data_bits
=
bits
self
.
_raw_data
=
bits
self
.
_raw_data
=
bits
self
.
_classify_result
=
None
if
lang
==
''
:
if
lang
==
''
:
self
.
_lang
=
None
self
.
_lang
=
None
elif
lang
==
'auto'
:
elif
lang
==
'auto'
:
from
magic_pdf.model.sub_modules.language_detection.utils
import
auto_detect_lang
from
magic_pdf.model.sub_modules.language_detection.utils
import
\
auto_detect_lang
self
.
_lang
=
auto_detect_lang
(
bits
)
self
.
_lang
=
auto_detect_lang
(
bits
)
logger
.
info
(
f
"
lang:
{
lang
}
, detect_lang:
{
self
.
_lang
}
"
)
logger
.
info
(
f
'
lang:
{
lang
}
, detect_lang:
{
self
.
_lang
}
'
)
else
:
else
:
self
.
_lang
=
lang
self
.
_lang
=
lang
logger
.
info
(
f
"lang:
{
lang
}
"
)
logger
.
info
(
f
'lang:
{
lang
}
'
)
def
__len__
(
self
)
->
int
:
def
__len__
(
self
)
->
int
:
"""The page number of the pdf."""
"""The page number of the pdf."""
return
len
(
self
.
_records
)
return
len
(
self
.
_records
)
...
@@ -186,12 +188,12 @@ class PymuDocDataset(Dataset):
...
@@ -186,12 +188,12 @@ class PymuDocDataset(Dataset):
return
self
.
_records
[
page_id
]
return
self
.
_records
[
page_id
]
def
dump_to_file
(
self
,
file_path
:
str
):
def
dump_to_file
(
self
,
file_path
:
str
):
"""Dump the file
"""Dump the file
.
Args:
Args:
file_path (str): the file path
file_path (str): the file path
"""
"""
dir_name
=
os
.
path
.
dirname
(
file_path
)
dir_name
=
os
.
path
.
dirname
(
file_path
)
if
dir_name
not
in
(
''
,
'.'
,
'..'
):
if
dir_name
not
in
(
''
,
'.'
,
'..'
):
os
.
makedirs
(
dir_name
,
exist_ok
=
True
)
os
.
makedirs
(
dir_name
,
exist_ok
=
True
)
...
@@ -212,18 +214,22 @@ class PymuDocDataset(Dataset):
...
@@ -212,18 +214,22 @@ class PymuDocDataset(Dataset):
return
proc
(
self
,
*
args
,
**
kwargs
)
return
proc
(
self
,
*
args
,
**
kwargs
)
def
classify
(
self
)
->
SupportedPdfParseMethod
:
def
classify
(
self
)
->
SupportedPdfParseMethod
:
"""classify the dataset
"""classify the dataset
.
Returns:
Returns:
SupportedPdfParseMethod: _description_
SupportedPdfParseMethod: _description_
"""
"""
return
classify
(
self
.
_data_bits
)
if
self
.
_classify_result
is
None
:
self
.
_classify_result
=
classify
(
self
.
_data_bits
)
return
self
.
_classify_result
def
clone
(
self
):
def
clone
(
self
):
"""clone this dataset
"""clone this dataset."""
"""
return
PymuDocDataset
(
self
.
_raw_data
)
return
PymuDocDataset
(
self
.
_raw_data
)
def
set_images
(
self
,
images
):
for
i
in
range
(
len
(
self
.
_records
)):
self
.
_records
[
i
].
set_image
(
images
[
i
])
class
ImageDataset
(
Dataset
):
class
ImageDataset
(
Dataset
):
def
__init__
(
self
,
bits
:
bytes
):
def
__init__
(
self
,
bits
:
bytes
):
...
@@ -270,10 +276,10 @@ class ImageDataset(Dataset):
...
@@ -270,10 +276,10 @@ class ImageDataset(Dataset):
return
self
.
_records
[
page_id
]
return
self
.
_records
[
page_id
]
def
dump_to_file
(
self
,
file_path
:
str
):
def
dump_to_file
(
self
,
file_path
:
str
):
"""Dump the file
"""Dump the file
.
Args:
Args:
file_path (str): the file path
file_path (str): the file path
"""
"""
dir_name
=
os
.
path
.
dirname
(
file_path
)
dir_name
=
os
.
path
.
dirname
(
file_path
)
if
dir_name
not
in
(
''
,
'.'
,
'..'
):
if
dir_name
not
in
(
''
,
'.'
,
'..'
):
...
@@ -293,7 +299,7 @@ class ImageDataset(Dataset):
...
@@ -293,7 +299,7 @@ class ImageDataset(Dataset):
return
proc
(
self
,
*
args
,
**
kwargs
)
return
proc
(
self
,
*
args
,
**
kwargs
)
def
classify
(
self
)
->
SupportedPdfParseMethod
:
def
classify
(
self
)
->
SupportedPdfParseMethod
:
"""classify the dataset
"""classify the dataset
.
Returns:
Returns:
SupportedPdfParseMethod: _description_
SupportedPdfParseMethod: _description_
...
@@ -301,15 +307,19 @@ class ImageDataset(Dataset):
...
@@ -301,15 +307,19 @@ class ImageDataset(Dataset):
return
SupportedPdfParseMethod
.
OCR
return
SupportedPdfParseMethod
.
OCR
def
clone
(
self
):
def
clone
(
self
):
"""clone this dataset
"""clone this dataset."""
"""
return
ImageDataset
(
self
.
_raw_data
)
return
ImageDataset
(
self
.
_raw_data
)
def
set_images
(
self
,
images
):
for
i
in
range
(
len
(
self
.
_records
)):
self
.
_records
[
i
].
set_image
(
images
[
i
])
class
Doc
(
PageableData
):
class
Doc
(
PageableData
):
"""Initialized with pymudoc object."""
"""Initialized with pymudoc object."""
def
__init__
(
self
,
doc
:
fitz
.
Page
):
def
__init__
(
self
,
doc
:
fitz
.
Page
):
self
.
_doc
=
doc
self
.
_doc
=
doc
self
.
_img
=
None
def
get_image
(
self
):
def
get_image
(
self
):
"""Return the image info.
"""Return the image info.
...
@@ -321,7 +331,17 @@ class Doc(PageableData):
...
@@ -321,7 +331,17 @@ class Doc(PageableData):
height: int
height: int
}
}
"""
"""
return
fitz_doc_to_image
(
self
.
_doc
)
if
self
.
_img
is
None
:
self
.
_img
=
fitz_doc_to_image
(
self
.
_doc
)
return
self
.
_img
def
set_image
(
self
,
img
):
"""
Args:
img (np.ndarray): the image
"""
if
self
.
_img
is
None
:
self
.
_img
=
img
def
get_doc
(
self
)
->
fitz
.
Page
:
def
get_doc
(
self
)
->
fitz
.
Page
:
"""Get the pymudoc object.
"""Get the pymudoc object.
...
...
magic_pdf/data/utils.py
View file @
41d96cd8
import
multiprocessing
as
mp
import
threading
from
concurrent.futures
import
(
ProcessPoolExecutor
,
ThreadPoolExecutor
,
as_completed
)
import
fitz
import
fitz
import
numpy
as
np
import
numpy
as
np
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.utils.annotations
import
ImportPIL
@
ImportPIL
def
fitz_doc_to_image
(
doc
,
dpi
=
200
)
->
dict
:
def
fitz_doc_to_image
(
doc
,
dpi
=
200
)
->
dict
:
"""Convert fitz.Document to image, Then convert the image to numpy array.
"""Convert fitz.Document to image, Then convert the image to numpy array.
...
@@ -17,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
...
@@ -17,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
Returns:
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
dict: {'img': numpy array, 'width': width, 'height': height }
"""
"""
from
PIL
import
Image
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
doc
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
pm
=
doc
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
...
@@ -25,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
...
@@ -25,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
doc
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
pm
=
doc
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
# Convert pixmap samples directly to numpy array
img
=
np
.
array
(
img
)
img
=
np
.
frombuffer
(
pm
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pm
.
height
,
pm
.
width
,
3
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
return
img_dict
return
img_dict
@
ImportPIL
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
from
PIL
import
Image
images
=
[]
images
=
[]
with
fitz
.
open
(
'pdf'
,
pdf_bytes
)
as
doc
:
with
fitz
.
open
(
'pdf'
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
pdf_page_num
=
doc
.
page_count
...
@@ -57,11 +57,110 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
...
@@ -57,11 +57,110 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
# Convert pixmap samples directly to numpy array
img
=
np
.
array
(
img
)
img
=
np
.
frombuffer
(
pm
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pm
.
height
,
pm
.
width
,
3
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
else
:
else
:
img_dict
=
{
'img'
:
[],
'width'
:
0
,
'height'
:
0
}
img_dict
=
{
'img'
:
[],
'width'
:
0
,
'height'
:
0
}
images
.
append
(
img_dict
)
images
.
append
(
img_dict
)
return
images
return
images
def
convert_page
(
bytes_page
):
pdfs
=
fitz
.
open
(
'pdf'
,
bytes_page
)
page
=
pdfs
[
0
]
return
fitz_doc_to_image
(
page
)
def
parallel_process_pdf_safe
(
pages
,
num_workers
=
None
,
**
kwargs
):
"""Process PDF pages in parallel with serialization-safe approach."""
if
num_workers
is
None
:
num_workers
=
mp
.
cpu_count
()
# Process the extracted page data in parallel
with
ProcessPoolExecutor
(
max_workers
=
num_workers
)
as
executor
:
# Process the page data
results
=
list
(
executor
.
map
(
convert_page
,
pages
)
)
return
results
def
threaded_process_pdf
(
pdf_path
,
num_threads
=
4
,
**
kwargs
):
"""Process all pages of a PDF using multiple threads.
Parameters:
-----------
pdf_path : str
Path to the PDF file
num_threads : int
Number of threads to use
**kwargs :
Additional arguments for fitz_doc_to_image
Returns:
--------
images : list
List of processed images, in page order
"""
# Open the PDF
doc
=
fitz
.
open
(
pdf_path
)
num_pages
=
len
(
doc
)
# Create a list to store results in the correct order
results
=
[
None
]
*
num_pages
# Create a thread pool
with
ThreadPoolExecutor
(
max_workers
=
num_threads
)
as
executor
:
# Submit all tasks
futures
=
{}
for
page_num
in
range
(
num_pages
):
page
=
doc
[
page_num
]
future
=
executor
.
submit
(
fitz_doc_to_image
,
page
,
**
kwargs
)
futures
[
future
]
=
page_num
# Process results as they complete with progress bar
for
future
in
as_completed
(
futures
):
page_num
=
futures
[
future
]
try
:
results
[
page_num
]
=
future
.
result
()
except
Exception
as
e
:
print
(
f
'Error processing page
{
page_num
}
:
{
e
}
'
)
results
[
page_num
]
=
None
# Close the document
doc
.
close
()
if
__name__
==
'__main__'
:
pdf
=
fitz
.
open
(
'/tmp/[MS-DOC].pdf'
)
pdf_page
=
[
fitz
.
open
()
for
i
in
range
(
pdf
.
page_count
)]
[
pdf_page
[
i
].
insert_pdf
(
pdf
,
from_page
=
i
,
to_page
=
i
)
for
i
in
range
(
pdf
.
page_count
)]
pdf_page
=
[
v
.
tobytes
()
for
v
in
pdf_page
]
results
=
parallel_process_pdf_safe
(
pdf_page
,
num_workers
=
16
)
# threaded_process_pdf('/tmp/[MS-DOC].pdf', num_threads=16)
""" benchmark results of multi-threaded processing (fitz page to image)
total page nums: 578
thread nums, time cost
1 7.351 sec
2 6.334 sec
4 5.968 sec
8 6.728 sec
16 8.085 sec
"""
""" benchmark results of multi-processor processing (fitz page to image)
total page nums: 578
processor nums, time cost
1 17.170 sec
2 10.170 sec
4 7.841 sec
8 7.900 sec
16 7.984 sec
"""
magic_pdf/dict2md/ocr_mkcontent.py
View file @
41d96cd8
...
@@ -208,12 +208,13 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
...
@@ -208,12 +208,13 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
'text'
:
merge_para_with_text
(
para_block
),
'text'
:
merge_para_with_text
(
para_block
),
}
}
elif
para_type
==
BlockType
.
Title
:
elif
para_type
==
BlockType
.
Title
:
title_level
=
get_title_level
(
para_block
)
para_content
=
{
para_content
=
{
'type'
:
'text'
,
'type'
:
'text'
,
'text'
:
merge_para_with_text
(
para_block
),
'text'
:
merge_para_with_text
(
para_block
),
'text_level'
:
title_level
,
}
}
title_level
=
get_title_level
(
para_block
)
if
title_level
!=
0
:
para_content
[
'text_level'
]
=
title_level
elif
para_type
==
BlockType
.
InterlineEquation
:
elif
para_type
==
BlockType
.
InterlineEquation
:
para_content
=
{
para_content
=
{
'type'
:
'equation'
,
'type'
:
'equation'
,
...
@@ -319,5 +320,5 @@ def get_title_level(block):
...
@@ -319,5 +320,5 @@ def get_title_level(block):
if
title_level
>
4
:
if
title_level
>
4
:
title_level
=
4
title_level
=
4
elif
title_level
<
1
:
elif
title_level
<
1
:
title_level
=
1
title_level
=
0
return
title_level
return
title_level
\ No newline at end of file
magic_pdf/libs/pdf_image_tools.py
View file @
41d96cd8
...
@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
...
@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 截取图片
# 截取图片
pix
=
page
.
get_pixmap
(
clip
=
rect
,
matrix
=
zoom
)
pix
=
page
.
get_pixmap
(
clip
=
rect
,
matrix
=
zoom
)
# 将字节数据转换为文件对象
image_file
=
BytesIO
(
pix
.
tobytes
(
output
=
'png'
))
# 使用 Pillow 打开图像
pil_image
=
Image
.
open
(
image_file
)
if
mode
==
"cv2"
:
if
mode
==
"cv2"
:
image_result
=
cv2
.
cvtColor
(
np
.
asarray
(
pil_image
),
cv2
.
COLOR_RGB2BGR
)
# 直接转换为numpy数组供cv2使用
img_array
=
np
.
frombuffer
(
pix
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pix
.
height
,
pix
.
width
,
pix
.
n
)
# PyMuPDF使用RGB顺序,而cv2使用BGR顺序
if
pix
.
n
==
3
or
pix
.
n
==
4
:
image_result
=
cv2
.
cvtColor
(
img_array
,
cv2
.
COLOR_RGB2BGR
)
else
:
image_result
=
img_array
elif
mode
==
"pillow"
:
elif
mode
==
"pillow"
:
image_result
=
pil_image
# 将字节数据转换为文件对象
image_file
=
BytesIO
(
pix
.
tobytes
(
output
=
'png'
))
# 使用 Pillow 打开图像
image_result
=
Image
.
open
(
image_file
)
else
:
else
:
raise
ValueError
(
f
"mode:
{
mode
}
is not supported."
)
raise
ValueError
(
f
"mode:
{
mode
}
is not supported."
)
...
...
magic_pdf/libs/performance_stats.py
View file @
41d96cd8
...
@@ -48,7 +48,18 @@ def measure_time(func):
...
@@ -48,7 +48,18 @@ def measure_time(func):
start_time
=
time
.
time
()
start_time
=
time
.
time
()
result
=
func
(
*
args
,
**
kwargs
)
result
=
func
(
*
args
,
**
kwargs
)
execution_time
=
time
.
time
()
-
start_time
execution_time
=
time
.
time
()
-
start_time
PerformanceStats
.
add_execution_time
(
func
.
__name__
,
execution_time
)
# 获取更详细的函数标识
if
hasattr
(
func
,
"__self__"
):
# 实例方法
class_name
=
func
.
__self__
.
__class__
.
__name__
full_name
=
f
"
{
class_name
}
.
{
func
.
__name__
}
"
elif
hasattr
(
func
,
"__qualname__"
):
# 类方法或静态方法
full_name
=
func
.
__qualname__
else
:
module_name
=
func
.
__module__
full_name
=
f
"
{
module_name
}
.
{
func
.
__name__
}
"
PerformanceStats
.
add_execution_time
(
full_name
,
execution_time
)
return
result
return
result
return
wrapper
return
wrapper
\ No newline at end of file
magic_pdf/model/batch_analyze.py
View file @
41d96cd8
This diff is collapsed.
Click to expand it.
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
41d96cd8
import
os
import
os
import
time
import
time
import
numpy
as
np
import
torch
import
torch
os
.
environ
[
'FLAGS_npu_jit_compile'
]
=
'0'
# 关闭paddle的jit编译
os
.
environ
[
'FLAGS_npu_jit_compile'
]
=
'0'
# 关闭paddle的jit编译
os
.
environ
[
'FLAGS_use_stride_kernel'
]
=
'0'
os
.
environ
[
'FLAGS_use_stride_kernel'
]
=
'0'
os
.
environ
[
'PYTORCH_ENABLE_MPS_FALLBACK'
]
=
'1'
# 让mps可以fallback
os
.
environ
[
'PYTORCH_ENABLE_MPS_FALLBACK'
]
=
'1'
# 让mps可以fallback
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
# 关闭paddle的信号处理
import
paddle
paddle
.
disable_signal_handler
()
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.model.batch_analyze
import
BatchAnalyze
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
from
magic_pdf.model.sub_modules.model_utils
import
get_vram
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
try
:
import
torchtext
if
torchtext
.
__version__
>=
'0.18.0'
:
torchtext
.
disable_torchtext_deprecation_warning
()
except
ImportError
:
pass
import
magic_pdf.model
as
model_config
import
magic_pdf.model
as
model_config
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.clean_memory
import
clean_memory
...
@@ -30,8 +22,6 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
...
@@ -30,8 +22,6 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir
,
get_local_models_dir
,
get_table_recog_config
)
get_table_recog_config
)
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.model.model_list
import
MODEL
from
magic_pdf.operators.models
import
InferenceResult
class
ModelSingleton
:
class
ModelSingleton
:
_instance
=
None
_instance
=
None
...
@@ -72,9 +62,7 @@ def custom_model_init(
...
@@ -72,9 +62,7 @@ def custom_model_init(
formula_enable
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
table_enable
=
None
,
):
):
model
=
None
model
=
None
if
model_config
.
__model_mode__
==
'lite'
:
if
model_config
.
__model_mode__
==
'lite'
:
logger
.
warning
(
logger
.
warning
(
'The Lite mode is provided for developers to conduct testing only, and the output quality is '
'The Lite mode is provided for developers to conduct testing only, and the output quality is '
...
@@ -132,7 +120,6 @@ def custom_model_init(
...
@@ -132,7 +120,6 @@ def custom_model_init(
return
custom_model
return
custom_model
def
doc_analyze
(
def
doc_analyze
(
dataset
:
Dataset
,
dataset
:
Dataset
,
ocr
:
bool
=
False
,
ocr
:
bool
=
False
,
...
@@ -143,102 +130,160 @@ def doc_analyze(
...
@@ -143,102 +130,160 @@ def doc_analyze(
layout_model
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
table_enable
=
None
,
)
->
InferenceResult
:
):
end_page_id
=
(
end_page_id
=
(
end_page_id
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
len
(
dataset
)
-
1
else
len
(
dataset
)
-
1
)
)
model_manager
=
ModelSingleton
()
MIN_BATCH_INFERENCE_SIZE
=
int
(
os
.
environ
.
get
(
'MINERU_MIN_BATCH_INFERENCE_SIZE'
,
200
))
custom_model
=
model_manager
.
get_model
(
images
=
[]
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
page_wh_list
=
[]
)
for
index
in
range
(
len
(
dataset
)):
if
start_page_id
<=
index
<=
end_page_id
:
batch_analyze
=
False
page_data
=
dataset
.
get_page
(
index
)
batch_ratio
=
1
img_dict
=
page_data
.
get_image
()
device
=
get_device
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
lang
is
None
or
lang
==
'auto'
:
images_with_extra_info
=
[(
images
[
index
],
ocr
,
dataset
.
_lang
)
for
index
in
range
(
len
(
dataset
))]
else
:
images_with_extra_info
=
[(
images
[
index
],
ocr
,
lang
)
for
index
in
range
(
len
(
dataset
))]
npu_support
=
False
if
len
(
images
)
>=
MIN_BATCH_INFERENCE_SIZE
:
if
str
(
device
).
startswith
(
"npu"
):
batch_size
=
MIN_BATCH_INFERENCE_SIZE
import
torch_npu
batch_images
=
[
images_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images_with_extra_info
),
batch_size
)]
if
torch_npu
.
npu
.
is_available
()
:
else
:
npu_support
=
True
batch_images
=
[
images_with_extra_info
]
if
torch
.
cuda
.
is_available
()
and
device
!=
'cpu'
or
npu_support
:
results
=
[]
gpu_memory
=
int
(
os
.
getenv
(
"VIRTUAL_VRAM_SIZE"
,
round
(
get_vram
(
device
))))
for
sn
,
batch_image
in
enumerate
(
batch_images
):
if
gpu_memory
is
not
None
and
gpu_memory
>=
8
:
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
ocr
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
.
extend
(
result
)
if
gpu_memory
>=
16
:
model_json
=
[]
batch_ratio
=
8
for
index
in
range
(
len
(
dataset
)):
elif
gpu_memory
>=
10
:
if
start_page_id
<=
index
<=
end_page_id
:
batch_ratio
=
4
result
=
results
.
pop
(
0
)
else
:
page_width
,
page_height
=
page_wh_list
.
pop
(
0
)
batch_ratio
=
2
else
:
result
=
[]
page_height
=
0
page_width
=
0
logger
.
info
(
f
'gpu_memory:
{
gpu_memory
}
GB, batch_ratio:
{
batch_ratio
}
'
)
page_info
=
{
'page_no'
:
index
,
'width'
:
page_width
,
'height'
:
page_height
}
batch_analyze
=
True
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
model_json
=
[]
from
magic_pdf.operators.models
import
InferenceResult
doc_analyze_start
=
time
.
time
(
)
return
InferenceResult
(
model_json
,
dataset
)
if
batch_analyze
:
def
batch_doc_analyze
(
# batch analyze
datasets
:
list
[
Dataset
],
images
=
[]
parse_method
:
str
,
page_wh_list
=
[]
show_log
:
bool
=
False
,
for
index
in
range
(
len
(
dataset
)):
lang
=
None
,
if
start_page_id
<=
index
<=
end_page_id
:
layout_model
=
None
,
page_data
=
dataset
.
get_page
(
index
)
formula_enable
=
None
,
img_dict
=
page_data
.
get_image
()
table_enable
=
None
,
images
.
append
(
img_dict
[
'img'
])
):
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
MIN_BATCH_INFERENCE_SIZE
=
int
(
os
.
environ
.
get
(
'MINERU_MIN_BATCH_INFERENCE_SIZE'
,
200
))
batch_model
=
BatchAnalyze
(
model
=
custom_model
,
batch_ratio
=
batch_ratio
)
batch_size
=
MIN_BATCH_INFERENCE_SIZE
analyze_result
=
batch_model
(
images
)
images
=
[]
page_wh_list
=
[]
images_with_extra_info
=
[]
for
dataset
in
datasets
:
for
index
in
range
(
len
(
dataset
)):
for
index
in
range
(
len
(
dataset
)):
if
start_page_id
<=
index
<=
end_page_id
:
if
lang
is
None
or
lang
==
'auto'
:
result
=
analyze_result
.
pop
(
0
)
_lang
=
dataset
.
_lang
page_width
,
page_height
=
page_wh_list
.
pop
(
0
)
else
:
else
:
result
=
[]
_lang
=
lang
page_height
=
0
page_width
=
0
page_info
=
{
'page_no'
:
index
,
'width'
:
page_width
,
'height'
:
page_height
}
page_data
=
dataset
.
get_page
(
index
)
img_dict
=
page_data
.
get_image
()
images
.
append
(
img_dict
[
'img'
])
page_wh_list
.
append
((
img_dict
[
'width'
],
img_dict
[
'height'
]))
if
parse_method
==
'auto'
:
images_with_extra_info
.
append
((
images
[
-
1
],
dataset
.
classify
()
==
SupportedPdfParseMethod
.
OCR
,
_lang
))
else
:
images_with_extra_info
.
append
((
images
[
-
1
],
parse_method
==
'ocr'
,
_lang
))
batch_images
=
[
images_with_extra_info
[
i
:
i
+
batch_size
]
for
i
in
range
(
0
,
len
(
images_with_extra_info
),
batch_size
)]
results
=
[]
for
sn
,
batch_image
in
enumerate
(
batch_images
):
_
,
result
=
may_batch_image_analyze
(
batch_image
,
sn
,
True
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
.
extend
(
result
)
infer_results
=
[]
from
magic_pdf.operators.models
import
InferenceResult
for
index
in
range
(
len
(
datasets
)):
dataset
=
datasets
[
index
]
model_json
=
[]
for
i
in
range
(
len
(
dataset
)):
result
=
results
.
pop
(
0
)
page_width
,
page_height
=
page_wh_list
.
pop
(
0
)
page_info
=
{
'page_no'
:
i
,
'width'
:
page_width
,
'height'
:
page_height
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
model_json
.
append
(
page_dict
)
infer_results
.
append
(
InferenceResult
(
model_json
,
dataset
))
return
infer_results
else
:
# single analyze
for
index
in
range
(
len
(
dataset
)):
def
may_batch_image_analyze
(
page_data
=
dataset
.
get_page
(
index
)
images_with_extra_info
:
list
[(
np
.
ndarray
,
bool
,
str
)],
img_dict
=
page_data
.
get_image
()
idx
:
int
,
img
=
img_dict
[
'img'
]
ocr
:
bool
,
page_width
=
img_dict
[
'width'
]
show_log
:
bool
=
False
,
page_height
=
img_dict
[
'height'
]
layout_model
=
None
,
if
start_page_id
<=
index
<=
end_page_id
:
formula_enable
=
None
,
page_start
=
time
.
time
()
table_enable
=
None
):
result
=
custom_model
(
img
)
# os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
logger
.
info
(
f
'-----page_id :
{
index
}
, page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----'
)
from
magic_pdf.model.batch_analyze
import
BatchAnalyze
model_manager
=
ModelSingleton
()
# images = [image for image, _, _ in images_with_extra_info]
batch_ratio
=
1
device
=
get_device
()
if
str
(
device
).
startswith
(
'npu'
):
import
torch_npu
if
torch_npu
.
npu
.
is_available
():
torch
.
npu
.
set_compile_mode
(
jit_compile
=
False
)
if
str
(
device
).
startswith
(
'npu'
)
or
str
(
device
).
startswith
(
'cuda'
):
gpu_memory
=
int
(
os
.
getenv
(
'VIRTUAL_VRAM_SIZE'
,
round
(
get_vram
(
device
))))
if
gpu_memory
is
not
None
:
if
gpu_memory
>=
16
:
batch_ratio
=
16
elif
gpu_memory
>=
12
:
batch_ratio
=
8
elif
gpu_memory
>=
8
:
batch_ratio
=
4
elif
gpu_memory
>=
6
:
batch_ratio
=
2
else
:
else
:
result
=
[]
batch_ratio
=
1
logger
.
info
(
f
'gpu_memory:
{
gpu_memory
}
GB, batch_ratio:
{
batch_ratio
}
'
)
page_info
=
{
'page_no'
:
index
,
'width'
:
page_width
,
'height'
:
page_height
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info
}
model_json
.
append
(
page_dict
)
gc_start
=
time
.
time
()
# doc_analyze_start = time.time()
clean_memory
(
get_device
())
gc_time
=
round
(
time
.
time
()
-
gc_start
,
2
)
logger
.
info
(
f
'gc time:
{
gc_time
}
'
)
doc_analyze_time
=
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
doc_analyze_speed
=
round
((
end_page_id
+
1
-
start_page_id
)
/
doc_analyze_time
,
2
)
logger
.
info
(
f
'doc analyze time:
{
round
(
time
.
time
()
-
doc_analyze_start
,
2
)
}
,'
f
' speed:
{
doc_analyze_speed
}
pages/second'
)
return
InferenceResult
(
model_json
,
dataset
)
batch_model
=
BatchAnalyze
(
model_manager
,
batch_ratio
,
show_log
,
layout_model
,
formula_enable
,
table_enable
)
results
=
batch_model
(
images_with_extra_info
)
# gc_start = time.time()
clean_memory
(
get_device
())
# gc_time = round(time.time() - gc_start, 2)
# logger.debug(f'gc time: {gc_time}')
# doc_analyze_time = round(time.time() - doc_analyze_start, 2)
# doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
# logger.debug(
# f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
# f' speed: {doc_analyze_speed} pages/second'
# )
return
idx
,
results
magic_pdf/model/pdf_extract_kit.py
View file @
41d96cd8
...
@@ -3,28 +3,18 @@ import os
...
@@ -3,28 +3,18 @@ import os
import
time
import
time
import
cv2
import
cv2
import
numpy
as
np
import
torch
import
torch
import
yaml
import
yaml
from
loguru
import
logger
from
loguru
import
logger
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
try
:
import
torchtext
if
torchtext
.
__version__
>=
'0.18.0'
:
torchtext
.
disable_torchtext_deprecation_warning
()
except
ImportError
:
pass
from
magic_pdf.config.constants
import
*
from
magic_pdf.config.constants
import
*
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
from
magic_pdf.model.sub_modules.model_utils
import
(
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
from
magic_pdf.model.sub_modules.ocr.paddleocr
2pytorch
.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
...
@@ -120,7 +110,7 @@ class CustomPEKModel:
...
@@ -120,7 +110,7 @@ class CustomPEKModel:
atom_model_name
=
AtomicModel
.
MFR
,
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_cfg_path
=
mfr_cfg_path
,
mfr_cfg_path
=
mfr_cfg_path
,
device
=
'cpu'
if
str
(
self
.
device
).
startswith
(
"mps"
)
else
self
.
device
,
device
=
self
.
device
,
)
)
# 初始化layout模型
# 初始化layout模型
...
@@ -174,11 +164,6 @@ class CustomPEKModel:
...
@@ -174,11 +164,6 @@ class CustomPEKModel:
logger
.
info
(
'DocAnalysis init done!'
)
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
image
):
def
__call__
(
self
,
image
):
pil_img
=
Image
.
fromarray
(
image
)
width
,
height
=
pil_img
.
size
# logger.info(f'width: {width}, height: {height}')
# layout检测
# layout检测
layout_start
=
time
.
time
()
layout_start
=
time
.
time
()
layout_res
=
[]
layout_res
=
[]
...
@@ -186,24 +171,6 @@ class CustomPEKModel:
...
@@ -186,24 +171,6 @@ class CustomPEKModel:
# layoutlmv3
# layoutlmv3
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
...
@@ -234,11 +201,11 @@ class CustomPEKModel:
...
@@ -234,11 +201,11 @@ class CustomPEKModel:
ocr_start
=
time
.
time
()
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
new_image
,
useful_list
=
crop_img
(
res
,
image
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
)
,
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
apply_ocr
:
if
self
.
apply_ocr
:
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
...
@@ -260,7 +227,7 @@ class CustomPEKModel:
...
@@ -260,7 +227,7 @@ class CustomPEKModel:
if
self
.
apply_table
:
if
self
.
apply_table
:
table_start
=
time
.
time
()
table_start
=
time
.
time
()
for
res
in
table_res_list
:
for
res
in
table_res_list
:
new_image
,
_
=
crop_img
(
res
,
pil_img
)
new_image
,
_
=
crop_img
(
res
,
image
)
single_table_start_time
=
time
.
time
()
single_table_start_time
=
time
.
time
()
html_code
=
None
html_code
=
None
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
...
...
magic_pdf/model/sub_modules/language_detection/utils.py
View file @
41d96cd8
...
@@ -3,8 +3,6 @@ import os
...
@@ -3,8 +3,6 @@ import os
from
pathlib
import
Path
from
pathlib
import
Path
import
yaml
import
yaml
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
from
magic_pdf.config.constants
import
MODEL_NAME
from
magic_pdf.config.constants
import
MODEL_NAME
...
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
...
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
)
)
text_images
=
[]
text_images
=
[]
for
simple_image
in
simple_images
:
for
simple_image
in
simple_images
:
image
=
Image
.
fromarray
(
simple_image
[
'img'
]
)
image
=
simple_image
[
'img'
]
layout_res
=
temp_layout_model
.
predict
(
image
)
layout_res
=
temp_layout_model
.
predict
(
image
)
# 给textblock截图
# 给textblock截图
for
res
in
layout_res
:
for
res
in
layout_res
:
...
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
...
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
# 初步清洗(宽和高都小于100)
# 初步清洗(宽和高都小于100)
if
x2
-
x1
<
100
and
y2
-
y1
<
100
:
if
x2
-
x1
<
100
and
y2
-
y1
<
100
:
continue
continue
text_images
.
append
(
image
.
crop
((
x1
,
y1
,
x2
,
y2
))
)
text_images
.
append
(
image
[
y1
:
y2
,
x1
:
x2
]
)
return
text_images
return
text_images
...
...
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
View file @
41d96cd8
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
import
time
import
time
from
collections
import
Counter
from
collections
import
Counter
from
uuid
import
uuid4
from
uuid
import
uuid4
import
cv2
import
numpy
as
np
import
torch
import
torch
from
PIL
import
Image
from
loguru
import
logger
from
loguru
import
logger
from
ultralytics
import
YOLO
from
ultralytics
import
YOLO
...
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
...
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
if
result_images
is
None
:
if
result_images
is
None
:
result_images
=
[]
result_images
=
[]
width
,
height
=
image
.
s
ize
height
,
width
=
image
.
s
hape
[:
2
]
long_side
=
max
(
width
,
height
)
# 获取较长边长度
long_side
=
max
(
width
,
height
)
# 获取较长边长度
if
long_side
<=
400
:
if
long_side
<=
400
:
...
@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
...
@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if
x
+
new_long_side
>
width
:
if
x
+
new_long_side
>
width
:
continue
continue
box
=
(
x
,
0
,
x
+
new_long_side
,
height
)
sub_image
=
image
[
0
:
height
,
x
:
x
+
new_long_side
]
sub_image
=
image
.
crop
(
box
)
sub_images
.
append
(
sub_image
)
sub_images
.
append
(
sub_image
)
else
:
# 如果高度是较长边
else
:
# 如果高度是较长边
for
y
in
range
(
0
,
height
,
new_long_side
):
for
y
in
range
(
0
,
height
,
new_long_side
):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if
y
+
new_long_side
>
height
:
if
y
+
new_long_side
>
height
:
continue
continue
box
=
(
0
,
y
,
width
,
y
+
new_long_side
)
sub_image
=
image
[
y
:
y
+
new_long_side
,
0
:
width
]
sub_image
=
image
.
crop
(
box
)
sub_images
.
append
(
sub_image
)
sub_images
.
append
(
sub_image
)
for
sub_image
in
sub_images
:
for
sub_image
in
sub_images
:
...
@@ -64,24 +62,32 @@ def split_images(image, result_images=None):
...
@@ -64,24 +62,32 @@ def split_images(image, result_images=None):
def
resize_images_to_224
(
image
):
def
resize_images_to_224
(
image
):
"""
"""
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
Works directly with NumPy arrays.
"""
"""
try
:
try
:
width
,
height
=
image
.
size
height
,
width
=
image
.
shape
[:
2
]
if
width
<
224
or
height
<
224
:
if
width
<
224
or
height
<
224
:
new_image
=
Image
.
new
(
'RGB'
,
(
224
,
224
),
(
0
,
0
,
0
))
# Create black background
paste_x
=
(
224
-
width
)
//
2
new_image
=
np
.
zeros
((
224
,
224
,
3
),
dtype
=
np
.
uint8
)
paste_y
=
(
224
-
height
)
//
2
# Calculate paste position (ensure they're not negative)
new_image
.
paste
(
image
,
(
paste_x
,
paste_y
))
paste_x
=
max
(
0
,
(
224
-
width
)
//
2
)
paste_y
=
max
(
0
,
(
224
-
height
)
//
2
)
# Make sure we don't exceed the boundaries of new_image
paste_width
=
min
(
width
,
224
)
paste_height
=
min
(
height
,
224
)
# Paste original image onto black background
new_image
[
paste_y
:
paste_y
+
paste_height
,
paste_x
:
paste_x
+
paste_width
]
=
image
[:
paste_height
,
:
paste_width
]
image
=
new_image
image
=
new_image
else
:
else
:
image
=
image
.
resize
((
224
,
224
),
Image
.
Resampling
.
LANCZOS
)
# Resize using cv2
image
=
cv2
.
resize
(
image
,
(
224
,
224
),
interpolation
=
cv2
.
INTER_LANCZOS4
)
# uuid = str(uuid4())
# image.save(f"/tmp/{uuid}.jpg")
return
image
return
image
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
exception
(
e
)
logger
.
exception
(
f
"Error in resize_images_to_224:
{
e
}
"
)
return
None
class
YOLOv11LangDetModel
(
object
):
class
YOLOv11LangDetModel
(
object
):
...
@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
...
@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
def
do_detect
(
self
,
images
:
list
):
def
do_detect
(
self
,
images
:
list
):
all_images
=
[]
all_images
=
[]
for
image
in
images
:
for
image
in
images
:
width
,
height
=
image
.
size
height
,
width
=
image
.
shape
[:
2
]
# logger.info(f"image size: {width} x {height}")
if
width
<
100
and
height
<
100
:
if
width
<
100
and
height
<
100
:
continue
continue
temp_images
=
split_images
(
image
)
temp_images
=
split_images
(
image
)
...
...
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py
View file @
41d96cd8
from
doclayout_yolo
import
YOLOv10
from
doclayout_yolo
import
YOLOv10
from
tqdm
import
tqdm
class
DocLayoutYOLOModel
(
object
):
class
DocLayoutYOLOModel
(
object
):
...
@@ -31,7 +32,8 @@ class DocLayoutYOLOModel(object):
...
@@ -31,7 +32,8 @@ class DocLayoutYOLOModel(object):
def
batch_predict
(
self
,
images
:
list
,
batch_size
:
int
)
->
list
:
def
batch_predict
(
self
,
images
:
list
,
batch_size
:
int
)
->
list
:
images_layout_res
=
[]
images_layout_res
=
[]
for
index
in
range
(
0
,
len
(
images
),
batch_size
):
# for index in range(0, len(images), batch_size):
for
index
in
tqdm
(
range
(
0
,
len
(
images
),
batch_size
),
desc
=
"Layout Predict"
):
doclayout_yolo_res
=
[
doclayout_yolo_res
=
[
image_res
.
cpu
()
image_res
.
cpu
()
for
image_res
in
self
.
model
.
predict
(
for
image_res
in
self
.
model
.
predict
(
...
...
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py
View file @
41d96cd8
from
tqdm
import
tqdm
from
ultralytics
import
YOLO
from
ultralytics
import
YOLO
...
@@ -14,7 +15,8 @@ class YOLOv8MFDModel(object):
...
@@ -14,7 +15,8 @@ class YOLOv8MFDModel(object):
def
batch_predict
(
self
,
images
:
list
,
batch_size
:
int
)
->
list
:
def
batch_predict
(
self
,
images
:
list
,
batch_size
:
int
)
->
list
:
images_mfd_res
=
[]
images_mfd_res
=
[]
for
index
in
range
(
0
,
len
(
images
),
batch_size
):
# for index in range(0, len(images), batch_size):
for
index
in
tqdm
(
range
(
0
,
len
(
images
),
batch_size
),
desc
=
"MFD Predict"
):
mfd_res
=
[
mfd_res
=
[
image_res
.
cpu
()
image_res
.
cpu
()
for
image_res
in
self
.
mfd_model
.
predict
(
for
image_res
in
self
.
mfd_model
.
predict
(
...
...
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
View file @
41d96cd8
import
argparse
import
os
import
re
import
torch
import
torch
import
unimernet.tasks
as
tasks
from
PIL
import
Image
from
torch.utils.data
import
DataLoader
,
Dataset
from
torch.utils.data
import
DataLoader
,
Dataset
from
torchvision
import
transforms
from
tqdm
import
tqdm
from
unimernet.common.config
import
Config
from
unimernet.processors
import
load_processor
class
MathDataset
(
Dataset
):
class
MathDataset
(
Dataset
):
...
@@ -20,55 +12,24 @@ class MathDataset(Dataset):
...
@@ -20,55 +12,24 @@ class MathDataset(Dataset):
return
len
(
self
.
image_paths
)
return
len
(
self
.
image_paths
)
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
# if not pil image, then convert to pil image
raw_image
=
self
.
image_paths
[
idx
]
if
isinstance
(
self
.
image_paths
[
idx
],
str
):
raw_image
=
Image
.
open
(
self
.
image_paths
[
idx
])
else
:
raw_image
=
self
.
image_paths
[
idx
]
if
self
.
transform
:
if
self
.
transform
:
image
=
self
.
transform
(
raw_image
)
image
=
self
.
transform
(
raw_image
)
return
image
return
image
def
latex_rm_whitespace
(
s
:
str
):
"""Remove unnecessary whitespace from LaTeX code."""
text_reg
=
r
"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
letter
=
"[a-zA-Z]"
noletter
=
"[\W_^\d]"
names
=
[
x
[
0
].
replace
(
" "
,
""
)
for
x
in
re
.
findall
(
text_reg
,
s
)]
s
=
re
.
sub
(
text_reg
,
lambda
match
:
str
(
names
.
pop
(
0
)),
s
)
news
=
s
while
True
:
s
=
news
news
=
re
.
sub
(
r
"(?!\\ )(%s)\s+?(%s)"
%
(
noletter
,
noletter
),
r
"\1\2"
,
s
)
news
=
re
.
sub
(
r
"(?!\\ )(%s)\s+?(%s)"
%
(
noletter
,
letter
),
r
"\1\2"
,
news
)
news
=
re
.
sub
(
r
"(%s)\s+?(%s)"
%
(
letter
,
noletter
),
r
"\1\2"
,
news
)
if
news
==
s
:
break
return
s
class
UnimernetModel
(
object
):
class
UnimernetModel
(
object
):
def
__init__
(
self
,
weight_dir
,
cfg_path
,
_device_
=
"cpu"
):
def
__init__
(
self
,
weight_dir
,
cfg_path
,
_device_
=
"cpu"
):
args
=
argparse
.
Namespace
(
cfg_path
=
cfg_path
,
options
=
None
)
from
.unimernet_hf
import
UnimernetModel
cfg
=
Config
(
args
)
if
_device_
.
startswith
(
"mps"
):
cfg
.
config
.
model
.
pretrained
=
os
.
path
.
join
(
weight_dir
,
"pytorch_model.pth"
)
self
.
model
=
UnimernetModel
.
from_pretrained
(
weight_dir
,
attn_implementation
=
"eager"
)
cfg
.
config
.
model
.
model_config
.
model_name
=
weight_dir
else
:
cfg
.
config
.
model
.
tokenizer_config
.
path
=
weight_dir
self
.
model
=
UnimernetModel
.
from_pretrained
(
weight_dir
)
task
=
tasks
.
setup_task
(
cfg
)
self
.
model
=
task
.
build_model
(
cfg
)
self
.
device
=
_device_
self
.
device
=
_device_
self
.
model
.
to
(
_device_
)
self
.
model
.
to
(
_device_
)
if
not
_device_
.
startswith
(
"cpu"
):
self
.
model
=
self
.
model
.
to
(
dtype
=
torch
.
float16
)
self
.
model
.
eval
()
self
.
model
.
eval
()
vis_processor
=
load_processor
(
"formula_image_eval"
,
cfg
.
config
.
datasets
.
formula_rec_eval
.
vis_processor
.
eval
,
)
self
.
mfr_transform
=
transforms
.
Compose
(
[
vis_processor
,
]
)
def
predict
(
self
,
mfd_res
,
image
):
def
predict
(
self
,
mfd_res
,
image
):
formula_list
=
[]
formula_list
=
[]
...
@@ -84,62 +45,22 @@ class UnimernetModel(object):
...
@@ -84,62 +45,22 @@ class UnimernetModel(object):
"latex"
:
""
,
"latex"
:
""
,
}
}
formula_list
.
append
(
new_item
)
formula_list
.
append
(
new_item
)
pil_img
=
Image
.
fromarray
(
image
)
bbox_img
=
image
[
ymin
:
ymax
,
xmin
:
xmax
]
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
mf_image_list
.
append
(
bbox_img
)
mf_image_list
.
append
(
bbox_img
)
dataset
=
MathDataset
(
mf_image_list
,
transform
=
self
.
m
fr_
transform
)
dataset
=
MathDataset
(
mf_image_list
,
transform
=
self
.
m
odel
.
transform
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
32
,
num_workers
=
0
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
32
,
num_workers
=
0
)
mfr_res
=
[]
mfr_res
=
[]
for
mf_img
in
dataloader
:
for
mf_img
in
dataloader
:
mf_img
=
mf_img
.
to
(
dtype
=
self
.
model
.
dtype
)
mf_img
=
mf_img
.
to
(
self
.
device
)
mf_img
=
mf_img
.
to
(
self
.
device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
mfr_res
.
extend
(
output
[
"
pr
ed_str"
])
mfr_res
.
extend
(
output
[
"
fix
ed_str"
])
for
res
,
latex
in
zip
(
formula_list
,
mfr_res
):
for
res
,
latex
in
zip
(
formula_list
,
mfr_res
):
res
[
"latex"
]
=
latex
_rm_whitespace
(
latex
)
res
[
"latex"
]
=
latex
return
formula_list
return
formula_list
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def
batch_predict
(
self
,
images_mfd_res
:
list
,
images
:
list
,
batch_size
:
int
=
64
)
->
list
:
def
batch_predict
(
self
,
images_mfd_res
:
list
,
images
:
list
,
batch_size
:
int
=
64
)
->
list
:
images_formula_list
=
[]
images_formula_list
=
[]
mf_image_list
=
[]
mf_image_list
=
[]
...
@@ -149,7 +70,7 @@ class UnimernetModel(object):
...
@@ -149,7 +70,7 @@ class UnimernetModel(object):
# Collect images with their original indices
# Collect images with their original indices
for
image_index
in
range
(
len
(
images_mfd_res
)):
for
image_index
in
range
(
len
(
images_mfd_res
)):
mfd_res
=
images_mfd_res
[
image_index
]
mfd_res
=
images_mfd_res
[
image_index
]
pil_img
=
Image
.
fromarray
(
images
[
image_index
]
)
np_array_image
=
images
[
image_index
]
formula_list
=
[]
formula_list
=
[]
for
idx
,
(
xyxy
,
conf
,
cla
)
in
enumerate
(
zip
(
for
idx
,
(
xyxy
,
conf
,
cla
)
in
enumerate
(
zip
(
...
@@ -163,7 +84,7 @@ class UnimernetModel(object):
...
@@ -163,7 +84,7 @@ class UnimernetModel(object):
"latex"
:
""
,
"latex"
:
""
,
}
}
formula_list
.
append
(
new_item
)
formula_list
.
append
(
new_item
)
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
y
max
))
bbox_img
=
np_array_image
[
ymin
:
ymax
,
xmin
:
x
max
]
area
=
(
xmax
-
xmin
)
*
(
ymax
-
ymin
)
area
=
(
xmax
-
xmin
)
*
(
ymax
-
ymin
)
curr_idx
=
len
(
mf_image_list
)
curr_idx
=
len
(
mf_image_list
)
...
@@ -182,22 +103,30 @@ class UnimernetModel(object):
...
@@ -182,22 +103,30 @@ class UnimernetModel(object):
index_mapping
=
{
new_idx
:
old_idx
for
new_idx
,
old_idx
in
enumerate
(
sorted_indices
)}
index_mapping
=
{
new_idx
:
old_idx
for
new_idx
,
old_idx
in
enumerate
(
sorted_indices
)}
# Create dataset with sorted images
# Create dataset with sorted images
dataset
=
MathDataset
(
sorted_images
,
transform
=
self
.
m
fr_
transform
)
dataset
=
MathDataset
(
sorted_images
,
transform
=
self
.
m
odel
.
transform
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
# Process batches and store results
# Process batches and store results
mfr_res
=
[]
mfr_res
=
[]
for
mf_img
in
dataloader
:
# for mf_img in dataloader:
mf_img
=
mf_img
.
to
(
self
.
device
)
with
torch
.
no_grad
():
with
tqdm
(
total
=
len
(
sorted_images
),
desc
=
"MFR Predict"
)
as
pbar
:
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
for
index
,
mf_img
in
enumerate
(
dataloader
):
mfr_res
.
extend
(
output
[
"pred_str"
])
mf_img
=
mf_img
.
to
(
dtype
=
self
.
model
.
dtype
)
mf_img
=
mf_img
.
to
(
self
.
device
)
with
torch
.
no_grad
():
output
=
self
.
model
.
generate
({
"image"
:
mf_img
})
mfr_res
.
extend
(
output
[
"fixed_str"
])
# 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
current_batch_size
=
min
(
batch_size
,
len
(
sorted_images
)
-
index
*
batch_size
)
pbar
.
update
(
current_batch_size
)
# Restore original order
# Restore original order
unsorted_results
=
[
""
]
*
len
(
mfr_res
)
unsorted_results
=
[
""
]
*
len
(
mfr_res
)
for
new_idx
,
latex
in
enumerate
(
mfr_res
):
for
new_idx
,
latex
in
enumerate
(
mfr_res
):
original_idx
=
index_mapping
[
new_idx
]
original_idx
=
index_mapping
[
new_idx
]
unsorted_results
[
original_idx
]
=
latex
_rm_whitespace
(
latex
)
unsorted_results
[
original_idx
]
=
latex
# Fill results back
# Fill results back
for
res
,
latex
in
zip
(
backfill_list
,
unsorted_results
):
for
res
,
latex
in
zip
(
backfill_list
,
unsorted_results
):
...
...
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py
0 → 100644
View file @
41d96cd8
from
.unimer_swin
import
UnimerSwinConfig
,
UnimerSwinModel
,
UnimerSwinImageProcessor
from
.unimer_mbart
import
UnimerMBartConfig
,
UnimerMBartModel
,
UnimerMBartForCausalLM
from
.modeling_unimernet
import
UnimernetModel
__all__
=
[
"UnimerSwinConfig"
,
"UnimerSwinModel"
,
"UnimerSwinImageProcessor"
,
"UnimerMBartConfig"
,
"UnimerMBartModel"
,
"UnimerMBartForCausalLM"
,
"UnimernetModel"
,
]
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py
0 → 100644
View file @
41d96cd8
import
os
import
re
import
warnings
from
typing
import
Optional
import
torch
from
ftfy
import
fix_text
from
transformers
import
AutoConfig
,
AutoModel
,
AutoModelForCausalLM
,
AutoTokenizer
,
PretrainedConfig
,
PreTrainedModel
from
transformers
import
VisionEncoderDecoderConfig
,
VisionEncoderDecoderModel
from
transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder
import
logger
as
base_model_logger
from
.unimer_swin
import
UnimerSwinConfig
,
UnimerSwinModel
,
UnimerSwinImageProcessor
from
.unimer_mbart
import
UnimerMBartConfig
,
UnimerMBartForCausalLM
AutoConfig
.
register
(
UnimerSwinConfig
.
model_type
,
UnimerSwinConfig
)
AutoConfig
.
register
(
UnimerMBartConfig
.
model_type
,
UnimerMBartConfig
)
AutoModel
.
register
(
UnimerSwinConfig
,
UnimerSwinModel
)
AutoModelForCausalLM
.
register
(
UnimerMBartConfig
,
UnimerMBartForCausalLM
)
# TODO: rewrite tokenizer
class
TokenizerWrapper
:
def
__init__
(
self
,
tokenizer
):
self
.
tokenizer
=
tokenizer
self
.
pad_token_id
=
self
.
tokenizer
.
pad_token_id
self
.
bos_token_id
=
self
.
tokenizer
.
bos_token_id
self
.
eos_token_id
=
self
.
tokenizer
.
eos_token_id
def
__len__
(
self
):
return
len
(
self
.
tokenizer
)
def
tokenize
(
self
,
text
,
**
kwargs
):
return
self
.
tokenizer
(
text
,
return_token_type_ids
=
False
,
return_tensors
=
"pt"
,
padding
=
"longest"
,
truncation
=
True
,
**
kwargs
,
)
def
token2str
(
self
,
tokens
)
->
list
:
generated_text
=
self
.
tokenizer
.
batch_decode
(
tokens
,
skip_special_tokens
=
True
)
generated_text
=
[
fix_text
(
text
)
for
text
in
generated_text
]
return
generated_text
def
detokenize
(
self
,
tokens
):
toks
=
[
self
.
tokenizer
.
convert_ids_to_tokens
(
tok
)
for
tok
in
tokens
]
for
b
in
range
(
len
(
toks
)):
for
i
in
reversed
(
range
(
len
(
toks
[
b
]))):
if
toks
[
b
][
i
]
is
None
:
toks
[
b
][
i
]
=
''
toks
[
b
][
i
]
=
toks
[
b
][
i
].
replace
(
'Ġ'
,
' '
).
strip
()
if
toks
[
b
][
i
]
in
([
self
.
tokenizer
.
bos_token
,
self
.
tokenizer
.
eos_token
,
self
.
tokenizer
.
pad_token
]):
del
toks
[
b
][
i
]
return
toks
def
latex_rm_whitespace
(
s
:
str
):
"""Remove unnecessary whitespace from LaTeX code.
"""
text_reg
=
r
'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
letter
=
r
'[a-zA-Z]'
noletter
=
r
'[\W_^\d]'
names
=
[
x
[
0
].
replace
(
' '
,
''
)
for
x
in
re
.
findall
(
text_reg
,
s
)]
s
=
re
.
sub
(
text_reg
,
lambda
_
:
str
(
names
.
pop
(
0
)),
s
)
news
=
s
while
True
:
s
=
news
news
=
re
.
sub
(
r
'(?!\\ )(%s)\s+?(%s)'
%
(
noletter
,
noletter
),
r
'\1\2'
,
s
)
news
=
re
.
sub
(
r
'(?!\\ )(%s)\s+?(%s)'
%
(
noletter
,
letter
),
r
'\1\2'
,
news
)
news
=
re
.
sub
(
r
'(%s)\s+?(%s)'
%
(
letter
,
noletter
),
r
'\1\2'
,
news
)
if
news
==
s
:
break
return
s
class UnimernetModel(VisionEncoderDecoderModel):
    """Vision-encoder/decoder model for formula recognition (image -> LaTeX).

    Wraps a UnimerSwin encoder and a UnimerMBart causal-LM decoder behind the
    standard transformers ``VisionEncoderDecoderModel`` interface, adding an
    image preprocessor (``self.transform``) and a wrapped tokenizer
    (``self.tokenizer``).
    """

    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        """Build the model and its tokenizer/image-processor companions.

        Args:
            config: must carry ``_name_or_path`` pointing at the model
                directory; the tokenizer is loaded from that path.
            encoder: optional pre-built vision encoder.
            decoder: optional pre-built text decoder.

        Raises:
            RuntimeError: if ``config._name_or_path`` is missing.
        """
        # VisionEncoderDecoderModel's checking log has bug, disable for temp.
        # try/finally guarantees the logger is re-enabled even if init raises.
        base_model_logger.disabled = True
        try:
            super().__init__(config, encoder, decoder)
        finally:
            base_model_logger.disabled = False

        if not config or not hasattr(config, "_name_or_path"):
            raise RuntimeError("config._name_or_path is required by UnimernetModel.")

        model_path = config._name_or_path
        self.transform = UnimerSwinImageProcessor()
        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
        # Validate tokenizer <-> decoder-config agreement right away.
        self._post_check()

    def _post_check(self):
        """Reconcile tokenizer settings with the decoder config.

        Forces ``tokenizer.model_max_length`` to match
        ``decoder.max_position_embeddings`` (with a warning), then asserts
        that vocab size and special-token ids agree between the two.
        """
        tokenizer = self.tokenizer

        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
            warnings.warn(
                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings},"
                + f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set"
                + f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
            # The decoder's positional table is the hard limit, so it wins.
            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings

        assert self.config.decoder.vocab_size == len(tokenizer)
        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
        assert self.config.pad_token_id == tokenizer.pad_token_id

    @classmethod
    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
        """Load a UnimernetModel from a raw ``.pth`` training checkpoint.

        Args:
            model_path: directory containing the HF config and the checkpoint.
            model_filename: checkpoint file name inside ``model_path``.
            state_dict_strip_prefix: prefix stripped from checkpoint keys
                (training wrapped the model, e.g. ``model.model.encoder...``).

        Returns:
            An initialized ``UnimernetModel`` with weights loaded.

        Raises:
            RuntimeError: if the state dict is empty or any key is missing.
        """
        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
        config._name_or_path = model_path
        # Re-wrap the generic sub-configs into the Unimer-specific classes.
        config.encoder = UnimerSwinConfig(**vars(config.encoder))
        config.decoder = UnimerMBartConfig(**vars(config.decoder))

        encoder = UnimerSwinModel(config.encoder)
        decoder = UnimerMBartForCausalLM(config.decoder)
        model = cls(config, encoder, decoder)

        # load model weights
        model_file_path = os.path.join(model_path, model_filename)
        # weights_only=True avoids arbitrary-code execution from pickled data.
        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
        # Some checkpoints nest weights under a "model" key; accept both.
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        if not state_dict:
            raise RuntimeError("state_dict is empty.")
        if state_dict_strip_prefix:
            state_dict = {
                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
                for k, v in state_dict.items()
            }

        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        # Extra keys are tolerated with a warning; missing keys are fatal.
        if len(unexpected_keys) > 0:
            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
        return model

    def forward_bak(self, samples):
        """Legacy training forward pass (kept for reference; note the ``_bak``
        suffix — not the active ``forward``).

        Computes the teacher-forced LM loss for a batch.

        Args:
            samples: dict with "image" (pixel tensor) and "text_input"
                (target LaTeX strings).

        Returns:
            dict with a single "loss" entry.
        """
        pixel_values, text = samples["image"], samples["text_input"]

        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]

        num_channels = pixel_values.shape[1]
        # Grayscale input: replicate to 3 channels for the vision encoder.
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        # Copy the ids, then mask padding positions with -100 so the loss
        # ignores them (standard HF convention).
        labels = decoder_input_ids * 1
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        # Shift: inputs drop the last token, labels drop the first.
        # NOTE(review): relies on ``self.model`` existing — presumably set by
        # the training wrapper; not defined in this class. Confirm before use.
        loss = self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids[:, :-1],
            decoder_attention_mask=decoder_attention_mask[:, :-1],
            labels=labels[:, 1:],
        ).loss
        return {"loss": loss}

    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
        """Generate LaTeX predictions for a batch of formula images.

        Args:
            samples: dict with "image" -> pixel tensor of shape (B, C, H, W).
            do_sample: enable sampling; when False, greedy decoding and the
                temperature/top_p arguments are ignored.
            temperature: sampling temperature (only when ``do_sample``).
            top_p: nucleus-sampling threshold (only when ``do_sample``).

        Returns:
            dict with "pred_ids" (numpy token ids, BOS stripped),
            "pred_tokens", "pred_str" (raw decoded LaTeX), and "fixed_str"
            (whitespace-normalized LaTeX).
        """
        pixel_values = samples["image"]
        num_channels = pixel_values.shape[1]
        # Grayscale input: replicate to 3 channels for the vision encoder.
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        # Only pass sampling knobs when sampling, to avoid HF warnings about
        # unused generation arguments in greedy mode.
        kwargs = {}
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p

        outputs = super().generate(
            pixel_values=pixel_values,
            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
            do_sample=do_sample,
            **kwargs,
        )
        # Drop the leading decoder-start token before detokenizing.
        outputs = outputs[:, 1:].cpu().numpy()
        pred_tokens = self.tokenizer.detokenize(outputs)
        pred_str = self.tokenizer.token2str(outputs)
        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py
0 → 100644
View file @
41d96cd8
"""Public API of the ``unimer_mbart`` subpackage: config and model classes."""
from .configuration_unimer_mbart import UnimerMBartConfig
from .modeling_unimer_mbart import UnimerMBartModel, UnimerMBartForCausalLM

__all__ = [
    "UnimerMBartConfig",
    "UnimerMBartModel",
    "UnimerMBartForCausalLM",
]
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py
0 → 100644
View file @
41d96cd8
# coding=utf-8
# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UnimerMBART model configuration"""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
# Module-level logger following the transformers convention.
logger = logging.get_logger(__name__)
class UnimerMBartConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MBART
    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        qk_squeeze (`int`, *optional*, defaults to 2):
            Squeeze ratio for query/key's output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
            Squeeze Attention maps the query and key to a lower-dimensional space without excessive loss of information,
            thereby accelerating the computation of attention.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by diving by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models)
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.

    Example:

    ```python
    >>> from transformers import MBartConfig, MBartModel

    >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
    >>> configuration = MBartConfig()

    >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
    >>> model = MBartModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "unimer-mbart"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Map the generic transformers attribute names onto this config's fields.
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        qk_squeeze=2,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        # UniMERNet-specific: query/key dimension squeeze ratio.
        self.qk_squeeze = qk_squeeze
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py
0 → 100644
View file @
41d96cd8
This diff is collapsed.
Click to expand it.
magic_pdf/model/sub_modules/
ocr/paddleocr/__init__
.py
→
magic_pdf/model/sub_modules/
mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart
.py
View file @
41d96cd8
File moved
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/__init__.py
0 → 100644
View file @
41d96cd8
"""Public API of the ``unimer_swin`` subpackage: config, model, and image processor."""
from .configuration_unimer_swin import UnimerSwinConfig
from .modeling_unimer_swin import UnimerSwinModel
from .image_processing_unimer_swin import UnimerSwinImageProcessor

__all__ = [
    "UnimerSwinConfig",
    "UnimerSwinModel",
    "UnimerSwinImageProcessor",
]
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment