Unverified Commit 1df26448 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #6 from myhloli/remove-pillow

Remove pillow
parents eae0e6d8 67b030eb
......@@ -8,10 +8,8 @@ import fitz
import numpy as np
from loguru import logger
from magic_pdf.utils.annotations import ImportPIL
@ImportPIL
def fitz_doc_to_image(doc, dpi=200) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array.
......@@ -22,7 +20,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
from PIL import Image
mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = doc.get_pixmap(matrix=mat, alpha=False)
......@@ -30,16 +27,14 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
if pm.width > 4500 or pm.height > 4500:
pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
# Convert pixmap samples directly to numpy array
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
return img_dict
@ImportPIL
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
from PIL import Image
images = []
with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count
......@@ -62,8 +57,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img)
# Convert pixmap samples directly to numpy array
img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else:
img_dict = {'img': [], 'width': 0, 'height': 0}
......
......@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom)
# 将字节数据转换为文件对象
image_file = BytesIO(pix.tobytes(output='png'))
# 使用 Pillow 打开图像
pil_image = Image.open(image_file)
if mode == "cv2":
image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
# 直接转换为numpy数组供cv2使用
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
# PyMuPDF使用RGB顺序,而cv2使用BGR顺序
if pix.n == 3 or pix.n == 4:
image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
else:
image_result = img_array
elif mode == "pillow":
image_result = pil_image
# 将字节数据转换为文件对象
image_file = BytesIO(pix.tobytes(output='png'))
# 使用 Pillow 打开图像
image_result = Image.open(image_file)
else:
raise ValueError(f"mode: {mode} is not supported.")
......
import time
import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image
from magic_pdf.config.constants import MODEL_NAME
# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
# from magic_pdf.data.dataset import Dataset
# from magic_pdf.libs.clean_memory import clean_memory
# from magic_pdf.libs.config_reader import get_device
# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
# from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
......@@ -31,7 +23,6 @@ class BatchAnalyze:
def __call__(self, images: list) -> list:
images_layout_res = []
layout_start_time = time.time()
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
......@@ -41,36 +32,14 @@ class BatchAnalyze:
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
layout_images = []
modified_images = []
for image_index, image in enumerate(images):
pil_img = Image.fromarray(image)
# width, height = pil_img.size
# if height > width:
# input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
# new_image, useful_list = crop_img(
# input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
# )
# layout_images.append(new_image)
# modified_images.append([image_index, useful_list])
# else:
layout_images.append(pil_img)
layout_images.append(image)
images_layout_res += self.model.layout_model.batch_predict(
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res['poly'])):
if i % 2 == 0:
res['poly'][i] = (
res['poly'][i] - useful_list[0] + useful_list[2]
)
else:
res['poly'][i] = (
res['poly'][i] - useful_list[1] + useful_list[3]
)
logger.info(
f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
)
......@@ -111,7 +80,7 @@ class BatchAnalyze:
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)):
layout_res = images_layout_res[index]
pil_img = Image.fromarray(images[index])
np_array_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
......@@ -121,14 +90,14 @@ class BatchAnalyze:
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(
res, pil_img, crop_paste_x=50, crop_paste_y=50
res, np_array_img, crop_paste_x=50, crop_paste_y=50
)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
single_page_mfdetrec_res, useful_list
)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.model.apply_ocr:
ocr_res = self.model.ocr_model.ocr(
......@@ -150,7 +119,7 @@ class BatchAnalyze:
if self.model.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
new_image, _ = crop_img(res, np_array_img)
single_table_start_time = time.time()
html_code = None
if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
......@@ -197,83 +166,3 @@ class BatchAnalyze:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
return images_layout_res
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
......@@ -3,11 +3,9 @@ import os
import time
import cv2
import numpy as np
import torch
import yaml
from loguru import logger
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
......@@ -174,11 +172,6 @@ class CustomPEKModel:
logger.info('DocAnalysis init done!')
def __call__(self, image):
pil_img = Image.fromarray(image)
width, height = pil_img.size
# logger.info(f'width: {width}, height: {height}')
# layout检测
layout_start = time.time()
layout_res = []
......@@ -186,24 +179,6 @@ class CustomPEKModel:
# layoutlmv3
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
......@@ -234,11 +209,11 @@ class CustomPEKModel:
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
......@@ -260,7 +235,7 @@ class CustomPEKModel:
if self.apply_table:
table_start = time.time()
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
new_image, _ = crop_img(res, image)
single_table_start_time = time.time()
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
......
......@@ -3,8 +3,6 @@ import os
from pathlib import Path
import yaml
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from magic_pdf.config.constants import MODEL_NAME
......@@ -42,7 +40,7 @@ def get_text_images(simple_images):
)
text_images = []
for simple_image in simple_images:
image = Image.fromarray(simple_image['img'])
image = simple_image['img']
layout_res = temp_layout_model.predict(image)
# 给textblock截图
for res in layout_res:
......@@ -51,7 +49,7 @@ def get_text_images(simple_images):
# 初步清洗(宽和高都小于100)
if x2 - x1 < 100 and y2 - y1 < 100:
continue
text_images.append(image.crop((x1, y1, x2, y2)))
text_images.append(image[y1:y2, x1:x2])
return text_images
......
......@@ -2,9 +2,9 @@
import time
from collections import Counter
from uuid import uuid4
import cv2
import numpy as np
import torch
from PIL import Image
from loguru import logger
from ultralytics import YOLO
......@@ -29,7 +29,7 @@ def split_images(image, result_images=None):
if result_images is None:
result_images = []
width, height = image.size
height, width = image.shape[:2]
long_side = max(width, height) # 获取较长边长度
if long_side <= 400:
......@@ -45,7 +45,7 @@ def split_images(image, result_images=None):
if x + new_long_side > width:
continue
box = (x, 0, x + new_long_side, height)
sub_image = image.crop(box)
sub_image = image[0:height, x:x + new_long_side]
sub_images.append(sub_image)
else: # 如果高度是较长边
for y in range(0, height, new_long_side):
......@@ -53,7 +53,7 @@ def split_images(image, result_images=None):
if y + new_long_side > height:
continue
box = (0, y, width, y + new_long_side)
sub_image = image.crop(box)
sub_image = image[y:y + new_long_side, 0:width]
sub_images.append(sub_image)
for sub_image in sub_images:
......@@ -64,24 +64,32 @@ def split_images(image, result_images=None):
def resize_images_to_224(image):
"""
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
Works directly with NumPy arrays.
"""
try:
width, height = image.size
height, width = image.shape[:2]
if width < 224 or height < 224:
new_image = Image.new('RGB', (224, 224), (0, 0, 0))
paste_x = (224 - width) // 2
paste_y = (224 - height) // 2
new_image.paste(image, (paste_x, paste_y))
# Create black background
new_image = np.zeros((224, 224, 3), dtype=np.uint8)
# Calculate paste position (ensure they're not negative)
paste_x = max(0, (224 - width) // 2)
paste_y = max(0, (224 - height) // 2)
# Make sure we don't exceed the boundaries of new_image
paste_width = min(width, 224)
paste_height = min(height, 224)
# Paste original image onto black background
new_image[paste_y:paste_y + paste_height, paste_x:paste_x + paste_width] = image[:paste_height, :paste_width]
image = new_image
else:
image = image.resize((224, 224), Image.Resampling.LANCZOS)
# Resize using cv2
image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)
# uuid = str(uuid4())
# image.save(f"/tmp/{uuid}.jpg")
return image
except Exception as e:
logger.exception(e)
logger.exception(f"Error in resize_images_to_224: {e}")
return None
class YOLOv11LangDetModel(object):
......@@ -96,8 +104,7 @@ class YOLOv11LangDetModel(object):
def do_detect(self, images: list):
all_images = []
for image in images:
width, height = image.size
# logger.info(f"image size: {width} x {height}")
height, width = image.shape[:2]
if width < 100 and height < 100:
continue
temp_images = split_images(image)
......
......@@ -4,7 +4,6 @@ import re
import torch
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
......@@ -19,16 +18,6 @@ class MathDataset(Dataset):
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# if not pil image, then convert to pil image
if isinstance(self.image_paths[idx], str):
raw_image = Image.open(self.image_paths[idx])
else:
raw_image = self.image_paths[idx]
if self.transform:
image = self.transform(raw_image)
return image
def latex_rm_whitespace(s: str):
"""Remove unnecessary whitespace from LaTeX code."""
......@@ -84,8 +73,7 @@ class UnimernetModel(object):
"latex": "",
}
formula_list.append(new_item)
pil_img = Image.fromarray(image)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
bbox_img = image[ymin:ymax, xmin:xmax]
mf_image_list.append(bbox_img)
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
......@@ -100,46 +88,6 @@ class UnimernetModel(object):
res["latex"] = latex_rm_whitespace(latex)
return formula_list
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
images_formula_list = []
mf_image_list = []
......@@ -149,7 +97,7 @@ class UnimernetModel(object):
# Collect images with their original indices
for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index]
pil_img = Image.fromarray(images[image_index])
np_array_image = images[image_index]
formula_list = []
for idx, (xyxy, conf, cla) in enumerate(zip(
......@@ -163,7 +111,7 @@ class UnimernetModel(object):
"latex": "",
}
formula_list.append(new_item)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
bbox_img = np_array_image[ymin:ymax, xmin:xmax]
area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list)
......
import time
import torch
from PIL import Image
from loguru import logger
import numpy as np
from magic_pdf.libs.clean_memory import clean_memory
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
# Create a white background with an additional width and height of 50
# Calculate new dimensions
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
# Crop image
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
cropped_img = input_pil_img.crop(crop_box)
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
# Create a white background array
return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
# Crop the original image using numpy slicing
cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
# Paste the cropped image onto the white background
return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
crop_new_height]
return return_image, return_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment