Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
1df26448
Unverified
Commit
1df26448
authored
Mar 19, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Mar 19, 2025
Browse files
Merge pull request #6 from myhloli/remove-pillow
Remove pillow
parents
eae0e6d8
67b030eb
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
70 additions
and
246 deletions
+70
-246
magic_pdf/data/utils.py
magic_pdf/data/utils.py
+5
-9
magic_pdf/libs/pdf_image_tools.py
magic_pdf/libs/pdf_image_tools.py
+11
-6
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+5
-116
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+3
-28
magic_pdf/model/sub_modules/language_detection/utils.py
magic_pdf/model/sub_modules/language_detection/utils.py
+2
-4
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
...f/model/sub_modules/language_detection/yolov11/YOLOv11.py
+24
-17
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
+3
-55
magic_pdf/model/sub_modules/model_utils.py
magic_pdf/model/sub_modules/model_utils.py
+17
-11
No files found.
magic_pdf/data/utils.py
View file @
1df26448
...
...
@@ -8,10 +8,8 @@ import fitz
import
numpy
as
np
from
loguru
import
logger
from
magic_pdf.utils.annotations
import
ImportPIL
@
ImportPIL
def
fitz_doc_to_image
(
doc
,
dpi
=
200
)
->
dict
:
"""Convert fitz.Document to image, Then convert the image to numpy array.
...
...
@@ -22,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
from
PIL
import
Image
mat
=
fitz
.
Matrix
(
dpi
/
72
,
dpi
/
72
)
pm
=
doc
.
get_pixmap
(
matrix
=
mat
,
alpha
=
False
)
...
...
@@ -30,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
doc
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
# Convert pixmap samples directly to numpy array
img
=
np
.
frombuffer
(
pm
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pm
.
height
,
pm
.
width
,
3
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
return
img_dict
@
ImportPIL
def
load_images_from_pdf
(
pdf_bytes
:
bytes
,
dpi
=
200
,
start_page_id
=
0
,
end_page_id
=
None
)
->
list
:
from
PIL
import
Image
images
=
[]
with
fitz
.
open
(
'pdf'
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
...
...
@@ -62,8 +57,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if
pm
.
width
>
4500
or
pm
.
height
>
4500
:
pm
=
page
.
get_pixmap
(
matrix
=
fitz
.
Matrix
(
1
,
1
),
alpha
=
False
)
img
=
Image
.
frombytes
(
'RGB'
,
(
pm
.
width
,
pm
.
height
),
pm
.
samples
)
img
=
np
.
array
(
img
)
# Convert pixmap samples directly to numpy array
img
=
np
.
frombuffer
(
pm
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pm
.
height
,
pm
.
width
,
3
)
img_dict
=
{
'img'
:
img
,
'width'
:
pm
.
width
,
'height'
:
pm
.
height
}
else
:
img_dict
=
{
'img'
:
[],
'width'
:
0
,
'height'
:
0
}
...
...
magic_pdf/libs/pdf_image_tools.py
View file @
1df26448
...
...
@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 截取图片
pix
=
page
.
get_pixmap
(
clip
=
rect
,
matrix
=
zoom
)
# 将字节数据转换为文件对象
image_file
=
BytesIO
(
pix
.
tobytes
(
output
=
'png'
))
# 使用 Pillow 打开图像
pil_image
=
Image
.
open
(
image_file
)
if
mode
==
"cv2"
:
image_result
=
cv2
.
cvtColor
(
np
.
asarray
(
pil_image
),
cv2
.
COLOR_RGB2BGR
)
# 直接转换为numpy数组供cv2使用
img_array
=
np
.
frombuffer
(
pix
.
samples
,
dtype
=
np
.
uint8
).
reshape
(
pix
.
height
,
pix
.
width
,
pix
.
n
)
# PyMuPDF使用RGB顺序,而cv2使用BGR顺序
if
pix
.
n
==
3
or
pix
.
n
==
4
:
image_result
=
cv2
.
cvtColor
(
img_array
,
cv2
.
COLOR_RGB2BGR
)
else
:
image_result
=
img_array
elif
mode
==
"pillow"
:
image_result
=
pil_image
# 将字节数据转换为文件对象
image_file
=
BytesIO
(
pix
.
tobytes
(
output
=
'png'
))
# 使用 Pillow 打开图像
image_result
=
Image
.
open
(
image_file
)
else
:
raise
ValueError
(
f
"mode:
{
mode
}
is not supported."
)
...
...
magic_pdf/model/batch_analyze.py
View file @
1df26448
import
time
import
cv2
import
numpy
as
np
import
torch
from
loguru
import
logger
from
PIL
import
Image
from
magic_pdf.config.constants
import
MODEL_NAME
# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
# from magic_pdf.data.dataset import Dataset
# from magic_pdf.libs.clean_memory import clean_memory
# from magic_pdf.libs.config_reader import get_device
# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from
magic_pdf.model.pdf_extract_kit
import
CustomPEKModel
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
# from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE
=
1
MFD_BASE_BATCH_SIZE
=
1
...
...
@@ -31,7 +23,6 @@ class BatchAnalyze:
def
__call__
(
self
,
images
:
list
)
->
list
:
images_layout_res
=
[]
layout_start_time
=
time
.
time
()
if
self
.
model
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
...
...
@@ -41,36 +32,14 @@ class BatchAnalyze:
elif
self
.
model
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
layout_images
=
[]
modified_images
=
[]
for
image_index
,
image
in
enumerate
(
images
):
pil_img
=
Image
.
fromarray
(
image
)
# width, height = pil_img.size
# if height > width:
# input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
# new_image, useful_list = crop_img(
# input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
# )
# layout_images.append(new_image)
# modified_images.append([image_index, useful_list])
# else:
layout_images
.
append
(
pil_img
)
layout_images
.
append
(
image
)
images_layout_res
+=
self
.
model
.
layout_model
.
batch_predict
(
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images
,
YOLO_LAYOUT_BASE_BATCH_SIZE
)
for
image_index
,
useful_list
in
modified_images
:
for
res
in
images_layout_res
[
image_index
]:
for
i
in
range
(
len
(
res
[
'poly'
])):
if
i
%
2
==
0
:
res
[
'poly'
][
i
]
=
(
res
[
'poly'
][
i
]
-
useful_list
[
0
]
+
useful_list
[
2
]
)
else
:
res
[
'poly'
][
i
]
=
(
res
[
'poly'
][
i
]
-
useful_list
[
1
]
+
useful_list
[
3
]
)
logger
.
info
(
f
'layout time:
{
round
(
time
.
time
()
-
layout_start_time
,
2
)
}
, image num:
{
len
(
images
)
}
'
)
...
...
@@ -111,7 +80,7 @@ class BatchAnalyze:
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for
index
in
range
(
len
(
images
)):
layout_res
=
images_layout_res
[
index
]
pil_img
=
Image
.
fromarray
(
images
[
index
]
)
np_array_img
=
images
[
index
]
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
(
get_res_list_from_layout_res
(
layout_res
)
...
...
@@ -121,14 +90,14 @@ class BatchAnalyze:
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil
_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
res
,
np_array
_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
)
,
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
model
.
apply_ocr
:
ocr_res
=
self
.
model
.
ocr_model
.
ocr
(
...
...
@@ -150,7 +119,7 @@ class BatchAnalyze:
if
self
.
model
.
apply_table
:
table_start
=
time
.
time
()
for
res
in
table_res_list
:
new_image
,
_
=
crop_img
(
res
,
pil
_img
)
new_image
,
_
=
crop_img
(
res
,
np_array
_img
)
single_table_start_time
=
time
.
time
()
html_code
=
None
if
self
.
model
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
...
...
@@ -197,83 +166,3 @@ class BatchAnalyze:
logger
.
info
(
f
'table time:
{
round
(
table_time
,
2
)
}
, image num:
{
table_count
}
'
)
return
images_layout_res
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
magic_pdf/model/pdf_extract_kit.py
View file @
1df26448
...
...
@@ -3,11 +3,9 @@ import os
import
time
import
cv2
import
numpy
as
np
import
torch
import
yaml
from
loguru
import
logger
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
...
...
@@ -174,11 +172,6 @@ class CustomPEKModel:
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
image
):
pil_img
=
Image
.
fromarray
(
image
)
width
,
height
=
pil_img
.
size
# logger.info(f'width: {width}, height: {height}')
# layout检测
layout_start
=
time
.
time
()
layout_res
=
[]
...
...
@@ -186,24 +179,6 @@ class CustomPEKModel:
# layoutlmv3
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
...
...
@@ -234,11 +209,11 @@ class CustomPEKModel:
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
new_image
,
useful_list
=
crop_img
(
res
,
image
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
)
,
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
new_image
,
cv2
.
COLOR_RGB2BGR
)
if
self
.
apply_ocr
:
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
...
...
@@ -260,7 +235,7 @@ class CustomPEKModel:
if
self
.
apply_table
:
table_start
=
time
.
time
()
for
res
in
table_res_list
:
new_image
,
_
=
crop_img
(
res
,
pil_img
)
new_image
,
_
=
crop_img
(
res
,
image
)
single_table_start_time
=
time
.
time
()
html_code
=
None
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
...
...
magic_pdf/model/sub_modules/language_detection/utils.py
View file @
1df26448
...
...
@@ -3,8 +3,6 @@ import os
from
pathlib
import
Path
import
yaml
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
from
magic_pdf.config.constants
import
MODEL_NAME
...
...
@@ -42,7 +40,7 @@ def get_text_images(simple_images):
)
text_images
=
[]
for
simple_image
in
simple_images
:
image
=
Image
.
fromarray
(
simple_image
[
'img'
]
)
image
=
simple_image
[
'img'
]
layout_res
=
temp_layout_model
.
predict
(
image
)
# 给textblock截图
for
res
in
layout_res
:
...
...
@@ -51,7 +49,7 @@ def get_text_images(simple_images):
# 初步清洗(宽和高都小于100)
if
x2
-
x1
<
100
and
y2
-
y1
<
100
:
continue
text_images
.
append
(
image
.
crop
((
x1
,
y1
,
x2
,
y2
))
)
text_images
.
append
(
image
[
y1
:
y2
,
x1
:
x2
]
)
return
text_images
...
...
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py
View file @
1df26448
...
...
@@ -2,9 +2,9 @@
import
time
from
collections
import
Counter
from
uuid
import
uuid4
import
cv2
import
numpy
as
np
import
torch
from
PIL
import
Image
from
loguru
import
logger
from
ultralytics
import
YOLO
...
...
@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
if
result_images
is
None
:
result_images
=
[]
width
,
height
=
image
.
s
ize
height
,
width
=
image
.
s
hape
[:
2
]
long_side
=
max
(
width
,
height
)
# 获取较长边长度
if
long_side
<=
400
:
...
...
@@ -45,7 +45,7 @@ def split_images(image, result_images=None):
if
x
+
new_long_side
>
width
:
continue
box
=
(
x
,
0
,
x
+
new_long_side
,
height
)
sub_image
=
image
.
crop
(
box
)
sub_image
=
image
[
0
:
height
,
x
:
x
+
new_long_side
]
sub_images
.
append
(
sub_image
)
else
:
# 如果高度是较长边
for
y
in
range
(
0
,
height
,
new_long_side
):
...
...
@@ -53,7 +53,7 @@ def split_images(image, result_images=None):
if
y
+
new_long_side
>
height
:
continue
box
=
(
0
,
y
,
width
,
y
+
new_long_side
)
sub_image
=
image
.
crop
(
box
)
sub_image
=
image
[
y
:
y
+
new_long_side
,
0
:
width
]
sub_images
.
append
(
sub_image
)
for
sub_image
in
sub_images
:
...
...
@@ -64,24 +64,32 @@ def split_images(image, result_images=None):
def
resize_images_to_224
(
image
):
"""
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
Works directly with NumPy arrays.
"""
try
:
width
,
height
=
image
.
size
height
,
width
=
image
.
shape
[:
2
]
if
width
<
224
or
height
<
224
:
new_image
=
Image
.
new
(
'RGB'
,
(
224
,
224
),
(
0
,
0
,
0
))
paste_x
=
(
224
-
width
)
//
2
paste_y
=
(
224
-
height
)
//
2
new_image
.
paste
(
image
,
(
paste_x
,
paste_y
))
# Create black background
new_image
=
np
.
zeros
((
224
,
224
,
3
),
dtype
=
np
.
uint8
)
# Calculate paste position (ensure they're not negative)
paste_x
=
max
(
0
,
(
224
-
width
)
//
2
)
paste_y
=
max
(
0
,
(
224
-
height
)
//
2
)
# Make sure we don't exceed the boundaries of new_image
paste_width
=
min
(
width
,
224
)
paste_height
=
min
(
height
,
224
)
# Paste original image onto black background
new_image
[
paste_y
:
paste_y
+
paste_height
,
paste_x
:
paste_x
+
paste_width
]
=
image
[:
paste_height
,
:
paste_width
]
image
=
new_image
else
:
image
=
image
.
resize
((
224
,
224
),
Image
.
Resampling
.
LANCZOS
)
# Resize using cv2
image
=
cv2
.
resize
(
image
,
(
224
,
224
),
interpolation
=
cv2
.
INTER_LANCZOS4
)
# uuid = str(uuid4())
# image.save(f"/tmp/{uuid}.jpg")
return
image
except
Exception
as
e
:
logger
.
exception
(
e
)
logger
.
exception
(
f
"Error in resize_images_to_224:
{
e
}
"
)
return
None
class
YOLOv11LangDetModel
(
object
):
...
...
@@ -96,8 +104,7 @@ class YOLOv11LangDetModel(object):
def
do_detect
(
self
,
images
:
list
):
all_images
=
[]
for
image
in
images
:
width
,
height
=
image
.
size
# logger.info(f"image size: {width} x {height}")
height
,
width
=
image
.
shape
[:
2
]
if
width
<
100
and
height
<
100
:
continue
temp_images
=
split_images
(
image
)
...
...
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py
View file @
1df26448
...
...
@@ -4,7 +4,6 @@ import re
import
torch
import
unimernet.tasks
as
tasks
from
PIL
import
Image
from
torch.utils.data
import
DataLoader
,
Dataset
from
torchvision
import
transforms
from
unimernet.common.config
import
Config
...
...
@@ -19,16 +18,6 @@ class MathDataset(Dataset):
def
__len__
(
self
):
return
len
(
self
.
image_paths
)
def
__getitem__
(
self
,
idx
):
# if not pil image, then convert to pil image
if
isinstance
(
self
.
image_paths
[
idx
],
str
):
raw_image
=
Image
.
open
(
self
.
image_paths
[
idx
])
else
:
raw_image
=
self
.
image_paths
[
idx
]
if
self
.
transform
:
image
=
self
.
transform
(
raw_image
)
return
image
def
latex_rm_whitespace
(
s
:
str
):
"""Remove unnecessary whitespace from LaTeX code."""
...
...
@@ -84,8 +73,7 @@ class UnimernetModel(object):
"latex"
:
""
,
}
formula_list
.
append
(
new_item
)
pil_img
=
Image
.
fromarray
(
image
)
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
bbox_img
=
image
[
ymin
:
ymax
,
xmin
:
xmax
]
mf_image_list
.
append
(
bbox_img
)
dataset
=
MathDataset
(
mf_image_list
,
transform
=
self
.
mfr_transform
)
...
...
@@ -100,46 +88,6 @@ class UnimernetModel(object):
res
[
"latex"
]
=
latex_rm_whitespace
(
latex
)
return
formula_list
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def
batch_predict
(
self
,
images_mfd_res
:
list
,
images
:
list
,
batch_size
:
int
=
64
)
->
list
:
images_formula_list
=
[]
mf_image_list
=
[]
...
...
@@ -149,7 +97,7 @@ class UnimernetModel(object):
# Collect images with their original indices
for
image_index
in
range
(
len
(
images_mfd_res
)):
mfd_res
=
images_mfd_res
[
image_index
]
pil_img
=
Image
.
fromarray
(
images
[
image_index
]
)
np_array_image
=
images
[
image_index
]
formula_list
=
[]
for
idx
,
(
xyxy
,
conf
,
cla
)
in
enumerate
(
zip
(
...
...
@@ -163,7 +111,7 @@ class UnimernetModel(object):
"latex"
:
""
,
}
formula_list
.
append
(
new_item
)
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
y
max
))
bbox_img
=
np_array_image
[
ymin
:
ymax
,
xmin
:
x
max
]
area
=
(
xmax
-
xmin
)
*
(
ymax
-
ymin
)
curr_idx
=
len
(
mf_image_list
)
...
...
magic_pdf/model/sub_modules/model_utils.py
View file @
1df26448
import
time
import
torch
from
PIL
import
Image
from
loguru
import
logger
import
numpy
as
np
from
magic_pdf.libs.clean_memory
import
clean_memory
def
crop_img
(
input_res
,
input_pil_img
,
crop_paste_x
=
0
,
crop_paste_y
=
0
):
def
crop_img
(
input_res
,
input_np_img
,
crop_paste_x
=
0
,
crop_paste_y
=
0
):
crop_xmin
,
crop_ymin
=
int
(
input_res
[
'poly'
][
0
]),
int
(
input_res
[
'poly'
][
1
])
crop_xmax
,
crop_ymax
=
int
(
input_res
[
'poly'
][
4
]),
int
(
input_res
[
'poly'
][
5
])
# Create a white background with an additional width and height of 50
# Calculate new dimensions
crop_new_width
=
crop_xmax
-
crop_xmin
+
crop_paste_x
*
2
crop_new_height
=
crop_ymax
-
crop_ymin
+
crop_paste_y
*
2
return_image
=
Image
.
new
(
'RGB'
,
(
crop_new_width
,
crop_new_height
),
'white'
)
# Crop image
crop_box
=
(
crop_xmin
,
crop_ymin
,
crop_xmax
,
crop_ymax
)
cropped_img
=
input_pil_img
.
crop
(
crop_box
)
return_image
.
paste
(
cropped_img
,
(
crop_paste_x
,
crop_paste_y
))
return_list
=
[
crop_paste_x
,
crop_paste_y
,
crop_xmin
,
crop_ymin
,
crop_xmax
,
crop_ymax
,
crop_new_width
,
crop_new_height
]
# Create a white background array
return_image
=
np
.
ones
((
crop_new_height
,
crop_new_width
,
3
),
dtype
=
np
.
uint8
)
*
255
# Crop the original image using numpy slicing
cropped_img
=
input_np_img
[
crop_ymin
:
crop_ymax
,
crop_xmin
:
crop_xmax
]
# Paste the cropped image onto the white background
return_image
[
crop_paste_y
:
crop_paste_y
+
(
crop_ymax
-
crop_ymin
),
crop_paste_x
:
crop_paste_x
+
(
crop_xmax
-
crop_xmin
)]
=
cropped_img
return_list
=
[
crop_paste_x
,
crop_paste_y
,
crop_xmin
,
crop_ymin
,
crop_xmax
,
crop_ymax
,
crop_new_width
,
crop_new_height
]
return
return_image
,
return_list
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment