wangsen / MinerU · Commit 1ec5d09d

Commit 1ec5d09d (unverified), authored Mar 20, 2025 by Xiaomeng Zhao, committed by GitHub on Mar 20, 2025

Merge pull request #1955 from myhloli/dev

Dev push

Parents: dd377537, 6e35e382

Showing 20 of 32 changed files, with 2813 additions and 309 deletions (+2813, −309).
Changed files:

  docker/ascend_npu/requirements.txt  (+2, −7)
  docker/china/requirements.txt  (+2, −7)
  docker/global/requirements.txt  (+2, −7)
  magic-pdf.template.json  (+1, −1)
  magic_pdf/data/utils.py  (+5, −9)
  magic_pdf/libs/pdf_image_tools.py  (+11, −6)
  magic_pdf/model/batch_analyze.py  (+5, −116)
  magic_pdf/model/doc_analyze_by_custom_model.py  (+7, −6)
  magic_pdf/model/pdf_extract_kit.py  (+4, −29)
  magic_pdf/model/sub_modules/language_detection/utils.py  (+2, −4)
  magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py  (+24, −19)
  magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py  (+2, −0)
  magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py  (+2, −0)
  magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py  (+20, −98)
  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py  (+13, −0)
  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py  (+189, −0)
  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py  (+8, −0)
  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py  (+163, −0)
  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py  (+2351, −0)
  magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py  (+0, −0)
docker/ascend_npu/requirements.txt

@@ -7,19 +7,14 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch==2.3.1
+torchvision==0.18.1
 matplotlib
 ultralytics>=8.3.48
-paddleocr==2.7.3
 paddlepaddle==3.0.0rc1
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
 openai
 detectron2
docker/china/requirements.txt

@@ -7,18 +7,13 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
 matplotlib
 ultralytics>=8.3.48
-paddleocr==2.7.3
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
 openai
 detectron2
docker/global/requirements.txt

@@ -7,18 +7,13 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect>=0.2.3,<0.3.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.2.3
-torch>=2.2.2,<=2.3.1
-torchvision>=0.17.2,<=0.18.1
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
 matplotlib
 ultralytics>=8.3.48
-paddleocr==2.7.3
-struct-eqtable==0.3.2
-einops
-accelerate
 rapidocr-paddle>=1.4.5,<2.0.0
 rapidocr-onnxruntime>=1.4.4,<2.0.0
 rapid-table>=1.0.3,<2.0.0
 doclayout-yolo==0.0.2b1
 openai
 detectron2
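Note: the china/global images relax torch from `>=2.2.2,<=2.3.1` to a range that admits up to 2.6.0 but excludes the 2.5.x releases, while the Ascend NPU image pins torch exactly. A quick way to sanity-check what a specifier admits (this uses the `packaging` library, which is not part of this diff):

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0")  # the new torch constraint
print("2.5.1" in spec)  # False: the 2.5.x releases are explicitly excluded
print("2.6.0" in spec)  # True
print("2.3.1" in spec)  # True: still compatible with the NPU image's exact pin
```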
magic-pdf.template.json

@@ -40,5 +40,5 @@
             "enable": false
         }
     },
-    "config_version": "1.1.1"
+    "config_version": "1.2.0"
 }
\ No newline at end of file
magic_pdf/data/utils.py

@@ -8,10 +8,8 @@ import fitz
 import numpy as np
 from loguru import logger

 from magic_pdf.utils.annotations import ImportPIL


 @ImportPIL
 def fitz_doc_to_image(doc, dpi=200) -> dict:
     """Convert fitz.Document to image, Then convert the image to numpy array.
@@ -22,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     Returns:
         dict: {'img': numpy array, 'width': width, 'height': height }
     """
-    from PIL import Image
     mat = fitz.Matrix(dpi / 72, dpi / 72)
     pm = doc.get_pixmap(matrix=mat, alpha=False)
@@ -30,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
     if pm.width > 4500 or pm.height > 4500:
         pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

-    img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-    img = np.array(img)
+    # Convert pixmap samples directly to numpy array
+    img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)

     img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
     return img_dict


 @ImportPIL
 def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
-    from PIL import Image
     images = []
     with fitz.open('pdf', pdf_bytes) as doc:
         pdf_page_num = doc.page_count
@@ -62,8 +57,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
                 if pm.width > 4500 or pm.height > 4500:
                     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

-                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
-                img = np.array(img)
+                # Convert pixmap samples directly to numpy array
+                img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
                 img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
             else:
                 img_dict = {'img': [], 'width': 0, 'height': 0}
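Note: the switch from a PIL round-trip (`Image.frombytes` followed by `np.array`) to `np.frombuffer(...).reshape(...)` avoids decoding through Pillow and an extra buffer copy, and removes the lazy PIL import. A minimal sketch of the two paths (assumes PyMuPDF is installed; the input file name is hypothetical):

```python
import fitz  # PyMuPDF
import numpy as np

doc = fitz.open("sample.pdf")  # hypothetical input file
pm = doc[0].get_pixmap(matrix=fitz.Matrix(200 / 72, 200 / 72), alpha=False)

# Old path: decode through PIL, then copy into a numpy array
# img = np.array(Image.frombytes('RGB', (pm.width, pm.height), pm.samples))

# New path: reinterpret the raw RGB sample buffer directly
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
print(img.shape)  # (height, width, 3)
```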
magic_pdf/libs/pdf_image_tools.py

@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
     # Crop the image
     pix = page.get_pixmap(clip=rect, matrix=zoom)

-    # Convert the byte data into a file-like object
-    image_file = BytesIO(pix.tobytes(output='png'))
-    # Open the image with Pillow
-    pil_image = Image.open(image_file)
     if mode == "cv2":
-        image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
+        # Convert directly to a numpy array for cv2
+        img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
+        # PyMuPDF uses RGB channel order, while cv2 uses BGR
+        if pix.n == 3 or pix.n == 4:
+            image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
+        else:
+            image_result = img_array
     elif mode == "pillow":
-        image_result = pil_image
+        # Convert the byte data into a file-like object
+        image_file = BytesIO(pix.tobytes(output='png'))
+        # Open the image with Pillow
+        image_result = Image.open(image_file)
     else:
         raise ValueError(f"mode: {mode} is not supported.")
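Note: the cv2 branch now skips the PNG encode/decode round-trip entirely. The RGB-to-BGR point is just a channel swap; a quick standalone check (the pure-numpy equivalent is `img[..., ::-1]`):

```python
import cv2
import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255  # pure red in RGB order

bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
assert (bgr[..., 2] == 255).all()     # red now sits in the last channel of BGR
assert (bgr == rgb[..., ::-1]).all()  # same result as reversing the channel axis
```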
magic_pdf/model/batch_analyze.py

 import time

 import cv2
 import numpy as np
 import torch
 from loguru import logger
-from PIL import Image

 from magic_pdf.config.constants import MODEL_NAME
 # from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
 # from magic_pdf.data.dataset import Dataset
 # from magic_pdf.libs.clean_memory import clean_memory
 # from magic_pdf.libs.config_reader import get_device
 # from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
 # from magic_pdf.operators.models import InferenceResult

 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
@@ -31,7 +23,6 @@ class BatchAnalyze:
     def __call__(self, images: list) -> list:
         images_layout_res = []

         layout_start_time = time.time()
         if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
             # layoutlmv3
@@ -41,36 +32,14 @@ class BatchAnalyze:
         elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
             layout_images = []
-            modified_images = []
             for image_index, image in enumerate(images):
-                pil_img = Image.fromarray(image)
-                # width, height = pil_img.size
-                # if height > width:
-                #     input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
-                #     new_image, useful_list = crop_img(
-                #         input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
-                #     )
-                #     layout_images.append(new_image)
-                #     modified_images.append([image_index, useful_list])
-                # else:
-                layout_images.append(pil_img)
+                layout_images.append(image)

             images_layout_res += self.model.layout_model.batch_predict(
-                # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
                 layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )

-            for image_index, useful_list in modified_images:
-                for res in images_layout_res[image_index]:
-                    for i in range(len(res['poly'])):
-                        if i % 2 == 0:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[0] + useful_list[2]
-                            )
-                        else:
-                            res['poly'][i] = (
-                                res['poly'][i] - useful_list[1] + useful_list[3]
-                            )

         logger.info(
             f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
         )
@@ -111,7 +80,7 @@ class BatchAnalyze:
         # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
         for index in range(len(images)):
             layout_res = images_layout_res[index]
-            pil_img = Image.fromarray(images[index])
+            np_array_img = images[index]

             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
@@ -121,14 +90,14 @@ class BatchAnalyze:
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(
-                    res, pil_img, crop_paste_x=50, crop_paste_y=50
+                    res, np_array_img, crop_paste_x=50, crop_paste_y=50
                 )
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                     single_page_mfdetrec_res, useful_list
                 )

                 # OCR recognition
-                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

                 if self.model.apply_ocr:
                     ocr_res = self.model.ocr_model.ocr(
@@ -150,7 +119,7 @@ class BatchAnalyze:
             if self.model.apply_table:
                 table_start = time.time()
                 for res in table_res_list:
-                    new_image, _ = crop_img(res, pil_img)
+                    new_image, _ = crop_img(res, np_array_img)
                     single_table_start_time = time.time()
                     html_code = None
                     if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
@@ -197,83 +166,3 @@ class BatchAnalyze:
         logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')

         return images_layout_res
-
-
-# def doc_batch_analyze(
-#     dataset: Dataset,
-#     ocr: bool = False,
-#     show_log: bool = False,
-#     start_page_id=0,
-#     end_page_id=None,
-#     lang=None,
-#     layout_model=None,
-#     formula_enable=None,
-#     table_enable=None,
-#     batch_ratio: int | None = None,
-# ) -> InferenceResult:
-#     """Perform batch analysis on a document dataset.
-#
-#     Args:
-#         dataset (Dataset): The dataset containing document pages to be analyzed.
-#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-#         show_log (bool, optional): Flag to enable logging. Defaults to False.
-#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-#         lang (str, optional): Language for OCR. Defaults to None.
-#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
-#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
-#         table_enable (optional): Flag to enable table detection. Defaults to None.
-#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-#
-#     Raises:
-#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-#
-#     Returns:
-#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-#     """
-#
-#     if not torch.cuda.is_available():
-#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-#
-#     lang = None if lang == '' else lang
-#     # TODO: auto detect batch size
-#     batch_ratio = 1 if batch_ratio is None else batch_ratio
-#     end_page_id = end_page_id if end_page_id else len(dataset)
-#
-#     model_manager = ModelSingleton()
-#     custom_model: CustomPEKModel = model_manager.get_model(
-#         ocr, show_log, lang, layout_model, formula_enable, table_enable
-#     )
-#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-#
-#     model_json = []
-#
-#     # batch analyze
-#     images = []
-#     for index in range(len(dataset)):
-#         if start_page_id <= index <= end_page_id:
-#             page_data = dataset.get_page(index)
-#             img_dict = page_data.get_image()
-#             images.append(img_dict['img'])
-#     analyze_result = batch_model(images)
-#
-#     for index in range(len(dataset)):
-#         page_data = dataset.get_page(index)
-#         img_dict = page_data.get_image()
-#         page_width = img_dict['width']
-#         page_height = img_dict['height']
-#         if start_page_id <= index <= end_page_id:
-#             result = analyze_result.pop(0)
-#         else:
-#             result = []
-#
-#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-#         page_dict = {'layout_dets': result, 'page_info': page_info}
-#         model_json.append(page_dict)
-#
-#     # TODO: clean memory when gpu memory is not enough
-#     clean_memory_start_time = time.time()
-#     clean_memory(get_device())
-#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-#
-#     return InferenceResult(model_json, dataset)
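Note: with these hunks, `BatchAnalyze.__call__` consumes numpy arrays end to end: pages go straight into `layout_model.batch_predict`, and `crop_img` now receives the numpy page image instead of a PIL conversion. A hedged usage sketch (assumes an already-constructed `CustomPEKModel`; the `custom_model` variable is a placeholder, not something defined in this diff):

```python
import numpy as np
from magic_pdf.model.batch_analyze import BatchAnalyze

# Pages as HxWx3 uint8 arrays, i.e. the format load_images_from_pdf() now returns.
pages = [np.zeros((1123, 794, 3), dtype=np.uint8)]

# custom_model = ...  # a CustomPEKModel, e.g. via ModelSingleton (see doc_analyze_by_custom_model.py)
# batch = BatchAnalyze(model=custom_model, batch_ratio=1)
# layout_res_per_page = batch(pages)  # one list of layout detections per page
```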
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -256,27 +256,28 @@ def may_batch_image_analyze(
     batch_ratio = 1
     device = get_device()

     npu_support = False
     if str(device).startswith('npu'):
         import torch_npu
         if torch_npu.npu.is_available():
             npu_support = True
+            torch.npu.set_compile_mode(jit_compile=False)

     if torch.cuda.is_available() and device != 'cpu' or npu_support:
         if str(device).startswith('npu') or str(device).startswith('cuda'):
             gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
-            if gpu_memory is not None and gpu_memory >= 8:
+            if gpu_memory is not None:
                 if gpu_memory >= 20:
                     batch_ratio = 16
                 elif gpu_memory >= 15:
                     batch_ratio = 8
                 elif gpu_memory >= 10:
                     batch_ratio = 4
-                else:
+                elif gpu_memory >= 7:
                     batch_ratio = 2
+                else:
+                    batch_ratio = 1
                 logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
                 batch_analyze = True
     elif str(device).startswith('mps'):
         batch_analyze = True

     doc_analyze_start = time.time()

     if batch_analyze:
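Note: the batch ratio now degrades gracefully on small GPUs instead of requiring at least 8 GB of VRAM. A standalone sketch of the mapping, with thresholds taken from this hunk:

```python
def pick_batch_ratio(gpu_memory_gb: int) -> int:
    """Map available VRAM (GB) to a batch ratio, per the thresholds above."""
    if gpu_memory_gb >= 20:
        return 16
    elif gpu_memory_gb >= 15:
        return 8
    elif gpu_memory_gb >= 10:
        return 4
    elif gpu_memory_gb >= 7:
        return 2
    return 1

assert pick_batch_ratio(24) == 16
assert pick_batch_ratio(7) == 2   # 7 GB now batches; previously >= 8 GB was required
assert pick_batch_ratio(6) == 1   # and below that, batching stays on at ratio 1
```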
magic_pdf/model/pdf_extract_kit.py

@@ -3,11 +3,9 @@ import os
 import time

 import cv2
-import numpy as np
 import torch
 import yaml
 from loguru import logger
-from PIL import Image

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # Prevent albumentations from checking for updates
@@ -120,7 +118,7 @@ class CustomPEKModel:
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
-                device='cpu' if str(self.device).startswith("mps") else self.device,
+                device=self.device,
             )

         # Initialize the layout model
@@ -174,11 +172,6 @@ class CustomPEKModel:
         logger.info('DocAnalysis init done!')

     def __call__(self, image):
-        pil_img = Image.fromarray(image)
-        width, height = pil_img.size
-        # logger.info(f'width: {width}, height: {height}')

         # Layout detection
         layout_start = time.time()
         layout_res = []
@@ -186,24 +179,6 @@ class CustomPEKModel:
             # layoutlmv3
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            # if height > width:
-            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
-            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-            #     layout_res = self.layout_model.predict(new_image)
-            #     for res in layout_res:
-            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-            #         p1 = p1 - paste_x + xmin
-            #         p2 = p2 - paste_y + ymin
-            #         p3 = p3 - paste_x + xmin
-            #         p4 = p4 - paste_y + ymin
-            #         p5 = p5 - paste_x + xmin
-            #         p6 = p6 - paste_y + ymin
-            #         p7 = p7 - paste_x + xmin
-            #         p8 = p8 - paste_y + ymin
-            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            # else:
             layout_res = self.layout_model.predict(image)

         layout_cost = round(time.time() - layout_start, 2)
@@ -234,11 +209,11 @@ class CustomPEKModel:
             ocr_start = time.time()
             # Process each area that requires OCR processing
             for res in ocr_res_list:
-                new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
+                new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)

                 # OCR recognition
-                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

                 if self.apply_ocr:
                     ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
@@ -260,7 +235,7 @@ class CustomPEKModel:
         if self.apply_table:
             table_start = time.time()
             for res in table_res_list:
-                new_image, _ = crop_img(res, pil_img)
+                new_image, _ = crop_img(res, image)
                 single_table_start_time = time.time()
                 html_code = None
                 if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
magic_pdf/model/sub_modules/language_detection/utils.py

@@ -3,8 +3,6 @@ import os
 from pathlib import Path

 import yaml
-from PIL import Image

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # Prevent albumentations from checking for updates

 from magic_pdf.config.constants import MODEL_NAME
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
     )
     text_images = []
     for simple_image in simple_images:
-        image = Image.fromarray(simple_image['img'])
+        image = simple_image['img']
         layout_res = temp_layout_model.predict(image)
         # Crop each text block
         for res in layout_res:
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
             # Initial filtering (skip blocks whose width and height are both under 100)
             if x2 - x1 < 100 and y2 - y1 < 100:
                 continue
-            text_images.append(image.crop((x1, y1, x2, y2)))
+            text_images.append(image[y1:y2, x1:x2])
     return text_images
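Note: the PIL call `crop((x1, y1, x2, y2))` maps to the numpy slice `image[y1:y2, x1:x2]` (rows first, then columns). A quick equivalence check:

```python
import numpy as np
from PIL import Image

img = np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)
x1, y1, x2, y2 = 2, 3, 7, 9

pil_crop = np.array(Image.fromarray(img).crop((x1, y1, x2, y2)))
np_crop = img[y1:y2, x1:x2]
assert (pil_crop == np_crop).all()
```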
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -2,9 +2,9 @@
 import time
 from collections import Counter
 from uuid import uuid4

+import cv2
 import numpy as np
 import torch
-from PIL import Image
 from loguru import logger
 from ultralytics import YOLO
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
     if result_images is None:
         result_images = []

-    width, height = image.size
+    height, width = image.shape[:2]
     long_side = max(width, height)  # Get the length of the longer side

     if long_side <= 400:
@@ -44,16 +44,14 @@ def split_images(image, result_images=None):
             # Skip the crop if it would extend beyond the image bounds
             if x + new_long_side > width:
                 continue
-            box = (x, 0, x + new_long_side, height)
-            sub_image = image.crop(box)
+            sub_image = image[0:height, x:x + new_long_side]
             sub_images.append(sub_image)
     else:
         # If height is the longer side
         for y in range(0, height, new_long_side):
             # Skip the crop if it would extend beyond the image bounds
             if y + new_long_side > height:
                 continue
-            box = (0, y, width, y + new_long_side)
-            sub_image = image.crop(box)
+            sub_image = image[y:y + new_long_side, 0:width]
             sub_images.append(sub_image)

     for sub_image in sub_images:
@@ -64,24 +62,32 @@
 def resize_images_to_224(image):
     """
-    Pad images smaller than 224 to 224*224 on a black background; resize images at least 224 to 224*224, and save them to the output folder.
+    Pad images smaller than 224 to 224*224 on a black background; resize images at least 224 to 224*224.
+    Works directly with NumPy arrays.
     """
     try:
-        width, height = image.size
+        height, width = image.shape[:2]
         if width < 224 or height < 224:
-            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
-            paste_x = (224 - width) // 2
-            paste_y = (224 - height) // 2
-            new_image.paste(image, (paste_x, paste_y))
+            # Create black background
+            new_image = np.zeros((224, 224, 3), dtype=np.uint8)
+            # Calculate paste position (ensure they're not negative)
+            paste_x = max(0, (224 - width) // 2)
+            paste_y = max(0, (224 - height) // 2)
+            # Make sure we don't exceed the boundaries of new_image
+            paste_width = min(width, 224)
+            paste_height = min(height, 224)
+            # Paste original image onto black background
+            new_image[paste_y:paste_y + paste_height, paste_x:paste_x + paste_width] = image[:paste_height, :paste_width]
             image = new_image
         else:
-            image = image.resize((224, 224), Image.Resampling.LANCZOS)
+            # Resize using cv2
+            image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)

         # uuid = str(uuid4())
         # image.save(f"/tmp/{uuid}.jpg")
         return image
     except Exception as e:
-        logger.exception(e)
+        logger.exception(f"Error in resize_images_to_224: {e}")
         return None


 class YOLOv11LangDetModel(object):
@@ -96,8 +102,7 @@ class YOLOv11LangDetModel(object):
     def do_detect(self, images: list):
         all_images = []
         for image in images:
-            width, height = image.size
-            # logger.info(f"image size: {width} x {height}")
+            height, width = image.shape[:2]
             if width < 100 and height < 100:
                 continue

             temp_images = split_images(image)
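Note: the language-detection preprocessing is now pure numpy/cv2. A standalone check of the same pad-or-resize logic (thresholds and canvas size as in the hunk above):

```python
import cv2
import numpy as np

img = np.full((100, 150, 3), 255, dtype=np.uint8)  # smaller than 224 on both sides

canvas = np.zeros((224, 224, 3), dtype=np.uint8)   # black background
paste_x = max(0, (224 - img.shape[1]) // 2)
paste_y = max(0, (224 - img.shape[0]) // 2)
canvas[paste_y:paste_y + img.shape[0], paste_x:paste_x + img.shape[1]] = img
assert canvas.shape == (224, 224, 3)

big = np.zeros((500, 300, 3), dtype=np.uint8)      # at least 224: plain resize instead
assert cv2.resize(big, (224, 224), interpolation=cv2.INTER_LANCZOS4).shape == (224, 224, 3)
```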
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -4,6 +4,8 @@ from doclayout_yolo import YOLOv10

 class DocLayoutYOLOModel(object):
     def __init__(self, weight, device):
         self.model = YOLOv10(weight)
+        if not device.startswith("cpu"):
+            self.model.half()
         self.device = device

     def predict(self, image):
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -4,6 +4,8 @@ from ultralytics import YOLO

 class YOLOv8MFDModel(object):
     def __init__(self, weight, device="cpu"):
         self.mfd_model = YOLO(weight)
+        if not device.startswith("cpu"):
+            self.mfd_model.half()
         self.device = device

     def predict(self, image):
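Note: both YOLO wrappers (`DocLayoutYOLOModel` above and `YOLOv8MFDModel` here) now cast their weights to half precision whenever the target device is not CPU. A minimal illustration of what `.half()` does on a plain torch module (not the ultralytics wrapper itself):

```python
import torch

layer = torch.nn.Conv2d(3, 8, 3)
assert layer.weight.dtype == torch.float32
layer.half()  # in-place cast of parameters and buffers
assert layer.weight.dtype == torch.float16  # halves weight memory; needs fp16-capable hardware
```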
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

-import argparse
-import os
-import re

 import torch
-import unimernet.tasks as tasks
-from PIL import Image
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from unimernet.common.config import Config
-from unimernet.processors import load_processor


 class MathDataset(Dataset):
@@ -20,55 +11,25 @@ class MathDataset(Dataset):
         return len(self.image_paths)

     def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
+        raw_image = self.image_paths[idx]
         if self.transform:
             image = self.transform(raw_image)
             return image


-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-    letter = "[a-zA-Z]"
-    noletter = "[\W_^\d]"
-    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-        if news == s:
-            break
-    return s


 class UnimernetModel(object):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
-        args = argparse.Namespace(cfg_path=cfg_path, options=None)
-        cfg = Config(args)
-        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-        cfg.config.model.model_config.model_name = weight_dir
-        cfg.config.model.tokenizer_config.path = weight_dir
-        task = tasks.setup_task(cfg)
-        self.model = task.build_model(cfg)
+        from .unimernet_hf import UnimernetModel
+        if _device_.startswith("mps"):
+            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
+        else:
+            self.model = UnimernetModel.from_pretrained(weight_dir)
         self.device = _device_
         self.model.to(_device_)
+        if not _device_.startswith("cpu"):
+            self.model = self.model.to(dtype=torch.float16)
         self.model.eval()
-        vis_processor = load_processor(
-            "formula_image_eval",
-            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
-        )
-        self.mfr_transform = transforms.Compose(
-            [
-                vis_processor,
-            ]
-        )

     def predict(self, mfd_res, image):
         formula_list = []
@@ -84,62 +45,22 @@ class UnimernetModel(object):
                 "latex": "",
             }
             formula_list.append(new_item)
-            pil_img = Image.fromarray(image)
-            bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+            bbox_img = image[ymin:ymax, xmin:xmax]
             mf_image_list.append(bbox_img)

-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataset = MathDataset(mf_image_list, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+            res["latex"] = latex
         return formula_list

-    # def batch_predict(
-    #     self, images_mfd_res: list, images: list, batch_size: int = 64
-    # ) -> list:
-    #     images_formula_list = []
-    #     mf_image_list = []
-    #     backfill_list = []
-    #     for image_index in range(len(images_mfd_res)):
-    #         mfd_res = images_mfd_res[image_index]
-    #         pil_img = Image.fromarray(images[image_index])
-    #         formula_list = []
-    #
-    #         for xyxy, conf, cla in zip(
-    #             mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
-    #         ):
-    #             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-    #             new_item = {
-    #                 "category_id": 13 + int(cla.item()),
-    #                 "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-    #                 "score": round(float(conf.item()), 2),
-    #                 "latex": "",
-    #             }
-    #             formula_list.append(new_item)
-    #             bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-    #             mf_image_list.append(bbox_img)
-    #
-    #         images_formula_list.append(formula_list)
-    #         backfill_list += formula_list
-    #
-    #     dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-    #     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
-    #     mfr_res = []
-    #     for mf_img in dataloader:
-    #         mf_img = mf_img.to(self.device)
-    #         with torch.no_grad():
-    #             output = self.model.generate({"image": mf_img})
-    #         mfr_res.extend(output["pred_str"])
-    #     for res, latex in zip(backfill_list, mfr_res):
-    #         res["latex"] = latex_rm_whitespace(latex)
-    #     return images_formula_list

     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
@@ -149,7 +70,7 @@ class UnimernetModel(object):
         # Collect images with their original indices
         for image_index in range(len(images_mfd_res)):
             mfd_res = images_mfd_res[image_index]
-            pil_img = Image.fromarray(images[image_index])
+            np_array_image = images[image_index]
             formula_list = []

             for idx, (xyxy, conf, cla) in enumerate(zip(
@@ -163,7 +84,7 @@ class UnimernetModel(object):
                     "latex": "",
                 }
                 formula_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                bbox_img = np_array_image[ymin:ymax, xmin:xmax]
                 area = (xmax - xmin) * (ymax - ymin)

                 curr_idx = len(mf_image_list)
@@ -182,22 +103,23 @@ class UnimernetModel(object):
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}

         # Create dataset with sorted images
-        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

         # Process batches and store results
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])

         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
             original_idx = index_mapping[new_idx]
-            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+            unsorted_results[original_idx] = latex

         # Fill results back
         for res, latex in zip(backfill_list, unsorted_results):
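Note: the wrapper now builds the formula-recognition model from the bundled HF-style implementation rather than the external `unimernet` package, and it reuses the model's own `transform` and `dtype` (whitespace cleanup happens inside `generate`, returned as `fixed_str`). A hedged sketch of the new load path (the weight directory is a hypothetical local path):

```python
import torch
from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf import UnimernetModel

weight_dir = "/models/MFR/unimernet_hf"  # hypothetical checkpoint directory
model = UnimernetModel.from_pretrained(weight_dir)  # attn_implementation="eager" on mps, per this hunk
model.eval()
if torch.cuda.is_available():
    model = model.to("cuda", dtype=torch.float16)  # mirrors the non-CPU float16 cast above
```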
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/__init__.py  (new file)

from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
from .unimer_mbart import UnimerMBartConfig, UnimerMBartModel, UnimerMBartForCausalLM
from .modeling_unimernet import UnimernetModel

__all__ = [
    "UnimerSwinConfig",
    "UnimerSwinModel",
    "UnimerSwinImageProcessor",
    "UnimerMBartConfig",
    "UnimerMBartModel",
    "UnimerMBartForCausalLM",
    "UnimernetModel",
]
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/modeling_unimernet.py  (new file)

import os
import re
import warnings
from typing import Optional

import torch
from ftfy import fix_text
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger

from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM

AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)


# TODO: rewrite tokenizer
class TokenizerWrapper:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_token_id = self.tokenizer.pad_token_id
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

    def __len__(self):
        return len(self.tokenizer)

    def tokenize(self, text, **kwargs):
        return self.tokenizer(
            text,
            return_token_type_ids=False,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            **kwargs,
        )

    def token2str(self, tokens) -> list:
        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
        generated_text = [fix_text(text) for text in generated_text]
        return generated_text

    def detokenize(self, tokens):
        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
        for b in range(len(toks)):
            for i in reversed(range(len(toks[b]))):
                if toks[b][i] is None:
                    toks[b][i] = ''
                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
                if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
                    del toks[b][i]
        return toks


def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = r'[a-zA-Z]'
    noletter = r'[\W_^\d]'
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda _: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
        if news == s:
            break
    return s


class UnimernetModel(VisionEncoderDecoderModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        # VisionEncoderDecoderModel's checking log has bug, disable for temp.
        base_model_logger.disabled = True
        try:
            super().__init__(config, encoder, decoder)
        finally:
            base_model_logger.disabled = False

        if not config or not hasattr(config, "_name_or_path"):
            raise RuntimeError("config._name_or_path is required by UnimernetModel.")

        model_path = config._name_or_path
        self.transform = UnimerSwinImageProcessor()
        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
        self._post_check()

    def _post_check(self):
        tokenizer = self.tokenizer

        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
            warnings.warn(
                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings},"
                + f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set"
                + f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings

        assert self.config.decoder.vocab_size == len(tokenizer)
        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
        assert self.config.pad_token_id == tokenizer.pad_token_id

    @classmethod
    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
        config._name_or_path = model_path
        config.encoder = UnimerSwinConfig(**vars(config.encoder))
        config.decoder = UnimerMBartConfig(**vars(config.decoder))

        encoder = UnimerSwinModel(config.encoder)
        decoder = UnimerMBartForCausalLM(config.decoder)
        model = cls(config, encoder, decoder)

        # load model weights
        model_file_path = os.path.join(model_path, model_filename)
        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        if not state_dict:
            raise RuntimeError("state_dict is empty.")
        if state_dict_strip_prefix:
            state_dict = {
                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
                for k, v in state_dict.items()
            }
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if len(unexpected_keys) > 0:
            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
        return model

    def forward_bak(self, samples):
        pixel_values, text = samples["image"], samples["text_input"]

        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]

        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        labels = decoder_input_ids * 1
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        loss = self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids[:, :-1],
            decoder_attention_mask=decoder_attention_mask[:, :-1],
            labels=labels[:, 1:],
        ).loss
        return {"loss": loss}

    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
        pixel_values = samples["image"]
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        kwargs = {}
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p

        outputs = super().generate(
            pixel_values=pixel_values,
            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
            do_sample=do_sample,
            **kwargs,
        )

        outputs = outputs[:, 1:].cpu().numpy()
        pred_tokens = self.tokenizer.detokenize(outputs)
        pred_str = self.tokenizer.token2str(outputs)
        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
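Note: for reference, a worked example of what the `latex_rm_whitespace` helper above does. It first de-spaces `\operatorname`/`\mathrm`/`\text`/`\mathbf { ... }` groups, then iteratively deletes whitespace between letter/non-letter pairs until it reaches a fixed point:

```python
from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.modeling_unimernet import latex_rm_whitespace

print(latex_rm_whitespace(r'\mathrm {x} + y ^ 2'))  # -> \mathrm{x}+y^2
print(latex_rm_whitespace(r'a + b = c'))            # -> a+b=c
```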
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/__init__.py  (new file)

from .configuration_unimer_mbart import UnimerMBartConfig
from .modeling_unimer_mbart import UnimerMBartModel, UnimerMBartForCausalLM

__all__ = [
    "UnimerMBartConfig",
    "UnimerMBartModel",
    "UnimerMBartForCausalLM",
]
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/configuration_unimer_mbart.py  (new file)

# coding=utf-8
# Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UnimerMBART model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class UnimerMBartConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the MBART
    [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        qk_squeeze (`int`, *optional*, defaults to 2):
            Squeeze ratio for the query/key output dimension. See the [UniMERNet paper](https://arxiv.org/abs/2404.15254).
            Squeeze attention maps the query and key to a lower-dimensional space without excessive loss of
            information, thereby accelerating the computation of attention.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.

    Example:

    ```python
    >>> from transformers import MBartConfig, MBartModel

    >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
    >>> configuration = MBartConfig()

    >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
    >>> model = MBartModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "unimer-mbart"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        use_cache=True,
        is_encoder_decoder=True,
        activation_function="gelu",
        d_model=1024,
        qk_squeeze=2,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.qk_squeeze = qk_squeeze
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )
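Note: mirroring the docstring's example style, the new config can be instantiated directly; `qk_squeeze` is the UniMERNet-specific knob described above (per the docstring, it squeezes the query/key projections by the given ratio to speed up attention). A minimal sketch using the classes added in this commit:

```python
from magic_pdf.model.sub_modules.mfr.unimernet.unimernet_hf.unimer_mbart import (
    UnimerMBartConfig, UnimerMBartModel)

# d_model=1024 with qk_squeeze=2 squeezes queries/keys by a factor of 2 inside attention
configuration = UnimerMBartConfig(d_model=1024, qk_squeeze=2)
model = UnimerMBartModel(configuration)  # random weights
assert model.config.qk_squeeze == 2
```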
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/modeling_unimer_mbart.py  (new file, +2351 lines; diff collapsed)
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_mbart/tokenization_unimer_mbart.py  (new empty file)