"...llama_fastertransformer.git" did not exist on "a0382aa63ee128184b51e9841bed228904af3849"
Commit 1b34f7e4 authored by myhloli's avatar myhloli
Browse files

refactor(magic_pdf): replace PIL with NumPy for image processing

- Remove PIL usage across multiple files
- Convert image processing functions to use NumPy arrays
- Update crop_img function to work with NumPy arrays
- Modify image loading and resizing to use NumPy and OpenCV
- Clean up unused imports and comments related to PIL
parent cf15c065
...@@ -3,10 +3,8 @@ import fitz ...@@ -3,10 +3,8 @@ import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
from magic_pdf.utils.annotations import ImportPIL
@ImportPIL
def fitz_doc_to_image(doc, dpi=200) -> dict: def fitz_doc_to_image(doc, dpi=200) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array. """Convert fitz.Document to image, Then convert the image to numpy array.
...@@ -17,7 +15,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict: ...@@ -17,7 +15,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
Returns: Returns:
dict: {'img': numpy array, 'width': width, 'height': height } dict: {'img': numpy array, 'width': width, 'height': height }
""" """
from PIL import Image
mat = fitz.Matrix(dpi / 72, dpi / 72) mat = fitz.Matrix(dpi / 72, dpi / 72)
pm = doc.get_pixmap(matrix=mat, alpha=False) pm = doc.get_pixmap(matrix=mat, alpha=False)
...@@ -25,8 +22,8 @@ def fitz_doc_to_image(doc, dpi=200) -> dict: ...@@ -25,8 +22,8 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
if pm.width > 4500 or pm.height > 4500: if pm.width > 4500 or pm.height > 4500:
pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples) # Convert pixmap samples directly to numpy array
img = np.array(img) img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height} img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
...@@ -34,7 +31,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict: ...@@ -34,7 +31,6 @@ def fitz_doc_to_image(doc, dpi=200) -> dict:
@ImportPIL @ImportPIL
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list: def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
from PIL import Image
images = [] images = []
with fitz.open('pdf', pdf_bytes) as doc: with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count pdf_page_num = doc.page_count
...@@ -57,8 +53,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id ...@@ -57,8 +53,9 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if pm.width > 4500 or pm.height > 4500: if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples) # Convert pixmap samples directly to numpy array
img = np.array(img) img = np.frombuffer(pm.samples, dtype=np.uint8).reshape(pm.height, pm.width, 3)
img_dict = {'img': img, 'width': pm.width, 'height': pm.height} img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else: else:
img_dict = {'img': [], 'width': 0, 'height': 0} img_dict = {'img': [], 'width': 0, 'height': 0}
......
...@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"): ...@@ -44,14 +44,19 @@ def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 截取图片 # 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom) pix = page.get_pixmap(clip=rect, matrix=zoom)
# 将字节数据转换为文件对象
image_file = BytesIO(pix.tobytes(output='png'))
# 使用 Pillow 打开图像
pil_image = Image.open(image_file)
if mode == "cv2": if mode == "cv2":
image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR) # 直接转换为numpy数组供cv2使用
img_array = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
# PyMuPDF使用RGB顺序,而cv2使用BGR顺序
if pix.n == 3 or pix.n == 4:
image_result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
else:
image_result = img_array
elif mode == "pillow": elif mode == "pillow":
image_result = pil_image # 将字节数据转换为文件对象
image_file = BytesIO(pix.tobytes(output='png'))
# 使用 Pillow 打开图像
image_result = Image.open(image_file)
else: else:
raise ValueError(f"mode: {mode} is not supported.") raise ValueError(f"mode: {mode} is not supported.")
......
import time import time
import cv2 import cv2
import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from PIL import Image
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
# from magic_pdf.data.dataset import Dataset
# from magic_pdf.libs.clean_memory import clean_memory
# from magic_pdf.libs.config_reader import get_device
# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list) get_adjusted_mfdetrec_res, get_ocr_result_list)
# from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 1 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1 MFD_BASE_BATCH_SIZE = 1
...@@ -31,7 +23,6 @@ class BatchAnalyze: ...@@ -31,7 +23,6 @@ class BatchAnalyze:
def __call__(self, images: list) -> list: def __call__(self, images: list) -> list:
images_layout_res = [] images_layout_res = []
layout_start_time = time.time() layout_start_time = time.time()
if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3: if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3 # layoutlmv3
...@@ -41,36 +32,14 @@ class BatchAnalyze: ...@@ -41,36 +32,14 @@ class BatchAnalyze:
elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo # doclayout_yolo
layout_images = [] layout_images = []
modified_images = []
for image_index, image in enumerate(images): for image_index, image in enumerate(images):
pil_img = Image.fromarray(image) layout_images.append(image)
# width, height = pil_img.size
# if height > width:
# input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
# new_image, useful_list = crop_img(
# input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
# )
# layout_images.append(new_image)
# modified_images.append([image_index, useful_list])
# else:
layout_images.append(pil_img)
images_layout_res += self.model.layout_model.batch_predict( images_layout_res += self.model.layout_model.batch_predict(
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
) )
for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res['poly'])):
if i % 2 == 0:
res['poly'][i] = (
res['poly'][i] - useful_list[0] + useful_list[2]
)
else:
res['poly'][i] = (
res['poly'][i] - useful_list[1] + useful_list[3]
)
logger.info( logger.info(
f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}' f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
) )
...@@ -111,7 +80,7 @@ class BatchAnalyze: ...@@ -111,7 +80,7 @@ class BatchAnalyze:
# reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
for index in range(len(images)): for index in range(len(images)):
layout_res = images_layout_res[index] layout_res = images_layout_res[index]
pil_img = Image.fromarray(images[index]) np_array_img = images[index]
ocr_res_list, table_res_list, single_page_mfdetrec_res = ( ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res) get_res_list_from_layout_res(layout_res)
...@@ -121,14 +90,14 @@ class BatchAnalyze: ...@@ -121,14 +90,14 @@ class BatchAnalyze:
# Process each area that requires OCR processing # Process each area that requires OCR processing
for res in ocr_res_list: for res in ocr_res_list:
new_image, useful_list = crop_img( new_image, useful_list = crop_img(
res, pil_img, crop_paste_x=50, crop_paste_y=50 res, np_array_img, crop_paste_x=50, crop_paste_y=50
) )
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res( adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
single_page_mfdetrec_res, useful_list single_page_mfdetrec_res, useful_list
) )
# OCR recognition # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.model.apply_ocr: if self.model.apply_ocr:
ocr_res = self.model.ocr_model.ocr( ocr_res = self.model.ocr_model.ocr(
...@@ -150,7 +119,7 @@ class BatchAnalyze: ...@@ -150,7 +119,7 @@ class BatchAnalyze:
if self.model.apply_table: if self.model.apply_table:
table_start = time.time() table_start = time.time()
for res in table_res_list: for res in table_res_list:
new_image, _ = crop_img(res, pil_img) new_image, _ = crop_img(res, np_array_img)
single_table_start_time = time.time() single_table_start_time = time.time()
html_code = None html_code = None
if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE: if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
...@@ -197,83 +166,3 @@ class BatchAnalyze: ...@@ -197,83 +166,3 @@ class BatchAnalyze:
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}') logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
return images_layout_res return images_layout_res
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
...@@ -3,11 +3,9 @@ import os ...@@ -3,11 +3,9 @@ import os
import time import time
import cv2 import cv2
import numpy as np
import torch import torch
import yaml import yaml
from loguru import logger from loguru import logger
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
...@@ -174,11 +172,6 @@ class CustomPEKModel: ...@@ -174,11 +172,6 @@ class CustomPEKModel:
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
def __call__(self, image): def __call__(self, image):
pil_img = Image.fromarray(image)
width, height = pil_img.size
# logger.info(f'width: {width}, height: {height}')
# layout检测 # layout检测
layout_start = time.time() layout_start = time.time()
layout_res = [] layout_res = []
...@@ -186,24 +179,6 @@ class CustomPEKModel: ...@@ -186,24 +179,6 @@ class CustomPEKModel:
# layoutlmv3 # layoutlmv3
layout_res = self.layout_model(image, ignore_catids=[]) layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res = self.layout_model.predict(image) layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2) layout_cost = round(time.time() - layout_start, 2)
...@@ -234,11 +209,11 @@ class CustomPEKModel: ...@@ -234,11 +209,11 @@ class CustomPEKModel:
ocr_start = time.time() ocr_start = time.time()
# Process each area that requires OCR processing # Process each area that requires OCR processing
for res in ocr_res_list: for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50) new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list) adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
if self.apply_ocr: if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
...@@ -260,7 +235,7 @@ class CustomPEKModel: ...@@ -260,7 +235,7 @@ class CustomPEKModel:
if self.apply_table: if self.apply_table:
table_start = time.time() table_start = time.time()
for res in table_res_list: for res in table_res_list:
new_image, _ = crop_img(res, pil_img) new_image, _ = crop_img(res, image)
single_table_start_time = time.time() single_table_start_time = time.time()
html_code = None html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE: if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
......
...@@ -3,8 +3,6 @@ import os ...@@ -3,8 +3,6 @@ import os
from pathlib import Path from pathlib import Path
import yaml import yaml
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
...@@ -42,7 +40,7 @@ def get_text_images(simple_images): ...@@ -42,7 +40,7 @@ def get_text_images(simple_images):
) )
text_images = [] text_images = []
for simple_image in simple_images: for simple_image in simple_images:
image = Image.fromarray(simple_image['img']) image = simple_image['img']
layout_res = temp_layout_model.predict(image) layout_res = temp_layout_model.predict(image)
# 给textblock截图 # 给textblock截图
for res in layout_res: for res in layout_res:
...@@ -51,7 +49,7 @@ def get_text_images(simple_images): ...@@ -51,7 +49,7 @@ def get_text_images(simple_images):
# 初步清洗(宽和高都小于100) # 初步清洗(宽和高都小于100)
if x2 - x1 < 100 and y2 - y1 < 100: if x2 - x1 < 100 and y2 - y1 < 100:
continue continue
text_images.append(image.crop((x1, y1, x2, y2))) text_images.append(image[y1:y2, x1:x2])
return text_images return text_images
......
...@@ -3,8 +3,8 @@ import time ...@@ -3,8 +3,8 @@ import time
from collections import Counter from collections import Counter
from uuid import uuid4 from uuid import uuid4
import numpy as np
import torch import torch
from PIL import Image
from loguru import logger from loguru import logger
from ultralytics import YOLO from ultralytics import YOLO
...@@ -64,21 +64,32 @@ def split_images(image, result_images=None): ...@@ -64,21 +64,32 @@ def split_images(image, result_images=None):
def resize_images_to_224(image): def resize_images_to_224(image):
""" """
若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小,并保存到输出文件夹中。 若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
Works directly with NumPy arrays.
""" """
try: try:
width, height = image.size # Handle numpy array directly
if len(image.shape) == 3: # Color image
height, width, channels = image.shape
else: # Grayscale image
height, width = image.shape
image = np.stack([image] * 3, axis=2) # Convert to RGB
if width < 224 or height < 224: if width < 224 or height < 224:
new_image = Image.new('RGB', (224, 224), (0, 0, 0)) # Create black background
new_image = np.zeros((224, 224, 3), dtype=np.uint8)
# Calculate paste position
paste_x = (224 - width) // 2 paste_x = (224 - width) // 2
paste_y = (224 - height) // 2 paste_y = (224 - height) // 2
new_image.paste(image, (paste_x, paste_y)) # Paste original image onto black background
new_image[paste_y:paste_y + height, paste_x:paste_x + width] = image
image = new_image image = new_image
else: else:
image = image.resize((224, 224), Image.Resampling.LANCZOS) # Resize using cv2 functionality or numpy interpolation
# Method 1: Using cv2 (preferred for better quality)
import cv2
image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)
# uuid = str(uuid4())
# image.save(f"/tmp/{uuid}.jpg")
return image return image
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
...@@ -96,8 +107,7 @@ class YOLOv11LangDetModel(object): ...@@ -96,8 +107,7 @@ class YOLOv11LangDetModel(object):
def do_detect(self, images: list): def do_detect(self, images: list):
all_images = [] all_images = []
for image in images: for image in images:
width, height = image.size height, width = image.shape[:2]
# logger.info(f"image size: {width} x {height}")
if width < 100 and height < 100: if width < 100 and height < 100:
continue continue
temp_images = split_images(image) temp_images = split_images(image)
......
...@@ -4,7 +4,6 @@ import re ...@@ -4,7 +4,6 @@ import re
import torch import torch
import unimernet.tasks as tasks import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from torchvision import transforms from torchvision import transforms
from unimernet.common.config import Config from unimernet.common.config import Config
...@@ -100,45 +99,6 @@ class UnimernetModel(object): ...@@ -100,45 +99,6 @@ class UnimernetModel(object):
res["latex"] = latex_rm_whitespace(latex) res["latex"] = latex_rm_whitespace(latex)
return formula_list return formula_list
# def batch_predict(
# self, images_mfd_res: list, images: list, batch_size: int = 64
# ) -> list:
# images_formula_list = []
# mf_image_list = []
# backfill_list = []
# for image_index in range(len(images_mfd_res)):
# mfd_res = images_mfd_res[image_index]
# pil_img = Image.fromarray(images[image_index])
# formula_list = []
#
# for xyxy, conf, cla in zip(
# mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
# ):
# xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
# new_item = {
# "category_id": 13 + int(cla.item()),
# "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
# "score": round(float(conf.item()), 2),
# "latex": "",
# }
# formula_list.append(new_item)
# bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
# mf_image_list.append(bbox_img)
#
# images_formula_list.append(formula_list)
# backfill_list += formula_list
#
# dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
# dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# mfr_res = []
# for mf_img in dataloader:
# mf_img = mf_img.to(self.device)
# with torch.no_grad():
# output = self.model.generate({"image": mf_img})
# mfr_res.extend(output["pred_str"])
# for res, latex in zip(backfill_list, mfr_res):
# res["latex"] = latex_rm_whitespace(latex)
# return images_formula_list
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
images_formula_list = [] images_formula_list = []
...@@ -149,7 +109,7 @@ class UnimernetModel(object): ...@@ -149,7 +109,7 @@ class UnimernetModel(object):
# Collect images with their original indices # Collect images with their original indices
for image_index in range(len(images_mfd_res)): for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index] mfd_res = images_mfd_res[image_index]
pil_img = Image.fromarray(images[image_index]) np_array_image = images[image_index]
formula_list = [] formula_list = []
for idx, (xyxy, conf, cla) in enumerate(zip( for idx, (xyxy, conf, cla) in enumerate(zip(
...@@ -163,7 +123,7 @@ class UnimernetModel(object): ...@@ -163,7 +123,7 @@ class UnimernetModel(object):
"latex": "", "latex": "",
} }
formula_list.append(new_item) formula_list.append(new_item)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax)) bbox_img = np_array_image[ymin:ymax, xmin:xmax]
area = (xmax - xmin) * (ymax - ymin) area = (xmax - xmin) * (ymax - ymin)
curr_idx = len(mf_image_list) curr_idx = len(mf_image_list)
......
import time import time
import torch import torch
from PIL import Image
from loguru import logger from loguru import logger
import numpy as np
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0): def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1]) crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5]) crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
# Create a white background with an additional width and height of 50
# Calculate new dimensions
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2 crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2 crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
# Crop image # Create a white background array
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax) return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
cropped_img = input_pil_img.crop(crop_box)
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y)) # Crop the original image using numpy slicing
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height] cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
# Paste the cropped image onto the white background
return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
crop_new_height]
return return_image, return_list return return_image, return_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment