Commit 1b34f7e4 authored by myhloli
Browse files

refactor(magic_pdf): replace PIL with NumPy for image processing

- Remove PIL usage across multiple files
- Convert image processing functions to use NumPy arrays
- Update crop_img function to work with NumPy arrays
- Modify image loading and resizing to use NumPy and OpenCV
- Clean up unused imports and comments related to PIL
parent cf15c065
......@@ -3,10 +3,8 @@ import fitz
import numpy as np
from loguru import logger
from magic_pdf.utils.annotations import ImportPIL
@ImportPIL
def fitz_doc_to_image(doc, dpi=200) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array.
......@@ -17,7 +15,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
from PIL import Image
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = doc.get_pixmap(matrix=mat, alpha=False)
......@@ -25,8 +22,8 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
if pm.width > 4500 or pm.height > 4500:
pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
# Convert pixmap samples directly to numpy array
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
......@@ -34,7 +31,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
@ImportPIL
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
from PIL import Image
images = []
with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count
......@@ -57,8 +53,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
# Convert pixmap samples directly to numpy array
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else:
img_dict = {'img': [], 'width': 0, 'height': 0}
......
......@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom)
if mode == "cv2":
# 直接转换为numpy数组供cv2使用
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
# PyMuPDF使用RGB顺序,而cv2使用BGR顺序
if pix.n == 3 or pix.n == 4:
image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
else:
image_result = img_array
elif mode == "pillow":
# 将字节数据转换为文件对象
image_file = BytesIO(pix.tobytes(output='png'))
# 使用 Pillow 打开图像
pil_image = Image.open(image_file)
if mode == "cv2":
image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
elif mode == "pillow":
image_result = pil_image
image_result = Image.open(image_file)
else:
raise ValueError(f"mode: {mode} is not supported.")
......
import time
import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image
from magic_pdf.config.constants import MODEL_NAME
# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
# from magic_pdf.data.dataset import Dataset
# from magic_pdf.libs.clean_memory import clean_memory
# from magic_pdf.libs.config_reader import get_device
# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
# from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
......@@ -31,7 +23,6 @@ class BatchAnalyze:
def __call__(self, images: list) -> list:
images_layout_res = []
layout_start_time = time.time()
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
......@@ -41,36 +32,14 @@ class BatchAnalyze:
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
layout_images = []
modified_images = []
for image_index, image in enumerate(images):
pil_img = Image.fromarray(image)
# width, height = pil_img.size
# if height > width:
# input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
# new_image, useful_list = crop_img(
# input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
# )
# layout_images.append(new_image)
# modified_images.append([image_index, useful_list])
# else:
layout_images.append(pil_img)
layout_images.append(image)
images_layout_res += self.model.layout_model.batch_predict(
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res['poly'])):
if i % 2 == 0:
res['poly'][i] = (
res['poly'][i] - useful_list[0] + useful_list[2]
)
else:
res['poly'][i] = (
res['poly'][i] - useful_list[1] + useful_list[3]
)
logger.info(
f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
)
......@@ -111,7 +80,7 @@ class BatchAnalyze:
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)):
layout_res = images_layout_res[index]
pil_img = Image.fromarray(images[index])
np_array_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
......@@ -121,14 +90,14 @@ class BatchAnalyze:
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(
res, pil_img, crop_paste_x=50, crop_paste_y=50
res, np_array_img, crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
single_page_mfdetrec_res, useful_list
)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.model.apply_ocr:
ocr_res = self.model.ocr_model.ocr(
......@@ -150,7 +119,7 @@ class BatchAnalyze:
if self.model.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
new_image, _ = crop_img(res, np_array_img)
single_table_start_time = time.time()
html_code = None
if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
......@@ -197,83 +166,3 @@ class BatchAnalyze:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
return images_layout_res
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
......@@ -3,11 +3,9 @@ import os
import time
import cv2
import numpy as np
import torch
import yaml
from loguru import logger
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
......@@ -174,11 +172,6 @@ class CustomPEKModel:
logger.info('DocAnalysis init done!')
def __call__(self, image):
pil_img = Image.fromarray(image)
width, height = pil_img.size
# logger.info(f'width: {width}, height: {height}')
# layout检测
layout_start = time.time()
layout_res = []
......@@ -186,24 +179,6 @@ class CustomPEKModel:
# layoutlmv3
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
......@@ -234,11 +209,11 @@ class CustomPEKModel:
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
......@@ -260,7 +235,7 @@ class CustomPEKModel:
if self.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
new_image, _ = crop_img(res, image)
single_table_start_time = time.time()
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
......
......@@ -3,8 +3,6 @@ import os
from pathlib import Path
import yaml
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from magic_pdf.config.constants import MODEL_NAME
......@@ -42,7 +40,7 @@ def get_text_images(simple_images):
)
text_images = []
for simple_image in simple_images:
image = Image.fromarray(simple_image['img'])
image = simple_image['img']
layout_res = temp_layout_model.predict(image)
# 给textblock截图
for res in layout_res:
......@@ -51,7 +49,7 @@ def get_text_images(simple_images):
# 初步清洗(宽和高都小于100)
if x2 - x1 < 100 and y2 - y1 < 100:
continue
text_images.append(image.crop((x1, y1, x2, y2)))
text_images.append(image[y1:y2, x1:x2])
return text_images
......
......@@ -3,8 +3,8 @@ import time
from collections import Counter
from uuid import uuid4
import numpy as np
import torch
from PIL import Image
from loguru import logger
from ultralytics import YOLO
......@@ -64,21 +64,32 @@ def split_images(image, result_images=None):
def resize_images_to_224(image):
    """Fit a NumPy image onto a 224x224 canvas.

    If either side is smaller than 224, the image is centred on a black
    224x224 RGB canvas; a side that is larger than 224 in that case is
    centre-cropped, so the copy never falls outside the canvas. If both
    sides are at least 224, the image is resized to 224x224 with Lanczos
    interpolation.

    Args:
        image: NumPy array, either 2-D (grayscale) or 3-D (height, width,
            channels). A 2-D input is replicated into three channels.

    Returns:
        The 224x224 NumPy image, or None if processing failed (the error
        is logged).
    """
    try:
        if len(image.shape) != 3:
            # Grayscale input: replicate the single plane into three channels.
            image = np.stack([image] * 3, axis=2)
        height, width = image.shape[:2]

        if width < 224 or height < 224:
            canvas = np.zeros((224, 224, 3), dtype=np.uint8)
            paste_x = (224 - width) // 2
            paste_y = (224 - height) // 2
            # A negative offset means that side is larger than the canvas.
            # NumPy would interpret a negative slice start as "from the end",
            # so clip the window explicitly (this is what PIL's paste did
            # implicitly).
            src_x, src_y = max(-paste_x, 0), max(-paste_y, 0)
            dst_x, dst_y = max(paste_x, 0), max(paste_y, 0)
            copy_w = min(width - src_x, 224 - dst_x)
            copy_h = min(height - src_y, 224 - dst_y)
            canvas[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = (
                image[src_y:src_y + copy_h, src_x:src_x + copy_w]
            )
            image = canvas
        else:
            # cv2 is imported here because only this branch needs it.
            import cv2
            image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)

        return image
    except Exception as e:
        # Best-effort helper: log the failure and signal it with None.
        logger.exception(e)
        return None
......@@ -96,8 +107,7 @@ class YOLOv11LangDetModel(object):
def do_detect(self, images: list):
all_images = []
for image in images:
width, height = image.size
# logger.info(f"image size: {width} x {height}")
height, width = image.shape[:2]
if width < 100 and height < 100:
continue
temp_images = split_images(image)
......
......@@ -4,7 +4,6 @@ import re
import torch
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
......@@ -100,45 +99,6 @@ class UnimernetModel(object):
res["latex"] = latex_rm_whitespace(latex)
return formula_list
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
images_formula_list = []
......@@ -149,7 +109,7 @@ class UnimernetModel(object):
# Collect images with their original indices
for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index]
pil_img = Image.fromarray(images[image_index])
np_array_image = images[image_index]
formula_list = []
for idx, (xyxy, conf, cla) in enumerate(zip(
......@@ -163,7 +123,7 @@ class UnimernetModel(object):
"latex": "",
}
formula_list.append(new_item)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
bbox_img = np_array_image[ymin:ymax, xmin:xmax]
area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list)
......
import time
import torch
from PIL import Image
from loguru import logger
import numpy as np
from magic_pdf.libs.clean_memory import clean_memory
def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
    """Crop a detection box out of a NumPy image onto a white, padded canvas.

    The box is read from ``input_res['poly']``: indices 0 and 1 are used as
    (xmin, ymin) and indices 4 and 5 as (xmax, ymax). The crop is placed on
    a white canvas that is ``crop_paste_x`` / ``crop_paste_y`` pixels larger
    on each side. Any part of the box that lies outside the image is left
    white instead of wrapping around or raising a shape error.

    Args:
        input_res: Dict with a ``'poly'`` sequence holding the box corners.
        input_np_img: Source image as a NumPy array, (height, width, 3) or
            2-D grayscale.
        crop_paste_x: Horizontal padding added on each side of the crop.
        crop_paste_y: Vertical padding added on each side of the crop.

    Returns:
        tuple: ``(return_image, return_list)`` where ``return_image`` is the
        uint8 canvas and ``return_list`` is ``[crop_paste_x, crop_paste_y,
        crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
        crop_new_height]``.
    """
    poly = input_res['poly']
    crop_xmin, crop_ymin = int(poly[0]), int(poly[1])
    crop_xmax, crop_ymax = int(poly[4]), int(poly[5])

    # Canvas size: the box plus the padding on both sides.
    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2

    # White background; a degenerate box yields an empty canvas, not an error.
    return_image = np.full(
        (max(crop_new_height, 0), max(crop_new_width, 0), 3), 255, dtype=np.uint8
    )

    # Clamp the source window to the image. Raw negative coordinates would be
    # interpreted by NumPy as "from the end", and coordinates past the edge
    # would yield a crop smaller than the destination slice.
    img_height, img_width = input_np_img.shape[:2]
    src_xmin = min(max(crop_xmin, 0), img_width)
    src_ymin = min(max(crop_ymin, 0), img_height)
    src_xmax = min(max(crop_xmax, src_xmin), img_width)
    src_ymax = min(max(crop_ymax, src_ymin), img_height)

    if src_xmax > src_xmin and src_ymax > src_ymin:
        cropped_img = input_np_img[src_ymin:src_ymax, src_xmin:src_xmax]
        if cropped_img.ndim == 2:
            # Grayscale source: add a channel axis so it broadcasts to RGB.
            cropped_img = cropped_img[..., None]
        # Shift the paste position by however much the window was clamped.
        dst_x = crop_paste_x + (src_xmin - crop_xmin)
        dst_y = crop_paste_y + (src_ymin - crop_ymin)
        return_image[dst_y:dst_y + (src_ymax - src_ymin),
                     dst_x:dst_x + (src_xmax - src_xmin)] = cropped_img

    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
                   crop_new_height]
    return return_image, return_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment