Commit af27c0cc authored by myhloli's avatar myhloli
Browse files

refactor(magic_pdf): support mps device and optimize image processing

- Add support for Apple M1 chips (mps device)
- Refactor image processing for better performance and compatibility
- Update model loading and inference for various devices
- Adjust batch processing and memory management
parent 31ebceb5
......@@ -256,27 +256,28 @@ def may_batch_image_analyze(
batch_ratio = 1
device = get_device()
npu_support = False
if str(device).startswith('npu'):
import torch_npu
if torch_npu.npu.is_available():
npu_support = True
torch.npu.set_compile_mode(jit_compile=False)
if torch.cuda.is_available() and device != 'cpu' or npu_support:
if str(device).startswith('npu') or str(device).startswith('cuda'):
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
if gpu_memory is not None and gpu_memory >= 8:
if gpu_memory is not None:
if gpu_memory >= 20:
batch_ratio = 16
elif gpu_memory >= 15:
batch_ratio = 8
elif gpu_memory >= 10:
batch_ratio = 4
else:
elif gpu_memory >= 7:
batch_ratio = 2
else:
batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_analyze = True
elif str(device).startswith('mps'):
batch_analyze = True
doc_analyze_start = time.time()
if batch_analyze:
......
......@@ -118,7 +118,7 @@ class CustomPEKModel:
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
device='cpu' if str(self.device).startswith("mps") else self.device,
device=self.device,
)
# 初始化layout模型
......
......@@ -44,7 +44,6 @@ def split_images(image, result_images=None):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if x + new_long_side > width:
continue
box = (x, 0, x + new_long_side, height)
sub_image = image[0:height, x:x + new_long_side]
sub_images.append(sub_image)
else: # 如果高度是较长边
......@@ -52,7 +51,6 @@ def split_images(image, result_images=None):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if y + new_long_side > height:
continue
box = (0, y, width, y + new_long_side)
sub_image = image[y:y + new_long_side, 0:width]
sub_images.append(sub_image)
......
......@@ -4,6 +4,8 @@ from doclayout_yolo import YOLOv10
class DocLayoutYOLOModel(object):
def __init__(self, weight, device):
self.model = YOLOv10(weight)
if not device.startswith("cpu"):
self.model.half()
self.device = device
def predict(self, image):
......
......@@ -4,6 +4,8 @@ from ultralytics import YOLO
class YOLOv8MFDModel(object):
def __init__(self, weight, device="cpu"):
self.mfd_model = YOLO(weight)
if not device.startswith("cpu"):
self.mfd_model.half()
self.device = device
def predict(self, image):
......
import argparse
import os
import re
import torch
import unimernet.tasks as tasks
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor
class MathDataset(Dataset):
......@@ -18,46 +10,26 @@ class MathDataset(Dataset):
def __len__(self):
return len(self.image_paths)
def latex_rm_whitespace(s: str):
"""Remove unnecessary whitespace from LaTeX code."""
text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
letter = "[a-zA-Z]"
noletter = "[\W_^\d]"
names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
news = s
while True:
s = news
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
if news == s:
break
return s
def __getitem__(self, idx):
raw_image = self.image_paths[idx]
if self.transform:
image = self.transform(raw_image)
return image
class UnimernetModel(object):
def __init__(self, weight_dir, cfg_path, _device_="cpu"):
args = argparse.Namespace(cfg_path=cfg_path, options=None)
cfg = Config(args)
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
cfg.config.model.model_config.model_name = weight_dir
cfg.config.model.tokenizer_config.path = weight_dir
task = tasks.setup_task(cfg)
self.model = task.build_model(cfg)
from .unimernet_hf import UnimernetModel
if _device_.startswith("mps"):
self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
else:
self.model = UnimernetModel.from_pretrained(weight_dir)
self.device = _device_
self.model.to(_device_)
if not _device_.startswith("cpu"):
self.model = self.model.to(dtype=torch.float16)
self.model.eval()
vis_processor = load_processor(
"formula_image_eval",
cfg.config.datasets.formula_rec_eval.vis_processor.eval,
)
self.mfr_transform = transforms.Compose(
[
vis_processor,
]
)
def predict(self, mfd_res, image):
formula_list = []
......@@ -76,16 +48,17 @@ class UnimernetModel(object):
bbox_img = image[ymin:ymax, xmin:xmax]
mf_image_list.append(bbox_img)
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataset = MathDataset(mf_image_list, transform=self.model.transform)
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(dtype=self.model.dtype)
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"])
mfr_res.extend(output["fixed_str"])
for res, latex in zip(formula_list, mfr_res):
res["latex"] = latex_rm_whitespace(latex)
res["latex"] = latex
return formula_list
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
......@@ -130,22 +103,23 @@ class UnimernetModel(object):
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
# Create dataset with sorted images
dataset = MathDataset(sorted_images, transform=self.mfr_transform)
dataset = MathDataset(sorted_images, transform=self.model.transform)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# Process batches and store results
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(dtype=self.model.dtype)
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"])
mfr_res.extend(output["fixed_str"])
# Restore original order
unsorted_results = [""] * len(mfr_res)
for new_idx, latex in enumerate(mfr_res):
original_idx = index_mapping[new_idx]
unsorted_results[original_idx] = latex_rm_whitespace(latex)
unsorted_results[original_idx] = latex
# Fill results back
for res, latex in zip(backfill_list, unsorted_results):
......
from transformers.image_processing_utils import BaseImageProcessor
from PIL import Image, ImageOps
import numpy as np
import cv2
import albumentations as alb
from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
# TODO: dereference cv2 if possible
class UnimerSwinImageProcessor(BaseImageProcessor):
def __init__(
self,
image_size = [192, 672],
image_size = (192, 672),
):
self.input_size = [int(_) for _ in image_size]
assert len(self.input_size) == 2
......@@ -27,56 +25,90 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
def __call__(self, item):
image = self.prepare_input(item)
return self.transform(image=np.array(image))['image'][:1]
return self.transform(image=image)['image'][:1]
@staticmethod
def crop_margin(img: Image.Image) -> Image.Image:
data = np.array(img.convert("L"))
data = data.astype(np.uint8)
max_val = data.max()
min_val = data.min()
if max_val == min_val:
def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
"""Crop margins of image using NumPy operations"""
# Convert to grayscale if it's a color image
if len(img.shape) == 3 and img.shape[2] == 3:
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
else:
gray = img.copy()
# Normalize and threshold
if gray.max() == gray.min():
return img
data = (data - min_val) / (max_val - min_val) * 255
gray = 255 * (data < 200).astype(np.uint8)
coords = cv2.findNonZero(gray) # Find all non-zero points (text)
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
return img.crop((a, b, w + a, h + b))
normalized = (((gray - gray.min()) / (gray.max() - gray.min())) * 255).astype(np.uint8)
binary = 255 * (normalized < 200).astype(np.uint8)
# Find bounding box
coords = cv2.findNonZero(binary) # Find all non-zero points (text)
x, y, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
def prepare_input(self, img: Image.Image, random_padding: bool = False):
# Return cropped image
return img[y:y + h, x:x + w]
def prepare_input(self, img, random_padding: bool = False):
"""
Convert PIL Image to tensor according to specified input_size after following steps below:
- resize
- rotate (if align_long_axis is True and image is not aligned longer axis with canvas)
- pad
Convert PIL Image or numpy array to properly sized and padded image after:
- crop margins
- resize while maintaining aspect ratio
- pad to target size
"""
if img is None:
return
# crop margins
return None
try:
img = self.crop_margin(img.convert("RGB"))
except OSError:
img = self.crop_margin_numpy(img)
except Exception:
# might throw an error for broken files
return
return None
if img.shape[0] == 0 or img.shape[1] == 0:
return None
# Resize while preserving aspect ratio
h, w = img.shape[:2]
scale = min(self.input_size[0] / h, self.input_size[1] / w)
new_h, new_w = int(h * scale), int(w * scale)
resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
# Calculate padding
pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
# Create and apply padding
channels = 3 if len(img.shape) == 3 else 1
padded_img = np.full((self.input_size[0], self.input_size[1], channels), 255, dtype=np.uint8)
padded_img[pad_height:pad_height + new_h, pad_width:pad_width + new_w] = resized_img
return padded_img
def _calculate_padding(self, new_w, new_h, random_padding):
"""Calculate padding values for PIL images"""
delta_width = self.input_size[1] - new_w
delta_height = self.input_size[0] - new_h
pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
if img.height == 0 or img.width == 0:
return
return (
pad_width,
pad_height,
delta_width - pad_width,
delta_height - pad_height,
)
def _get_padding_values(self, new_w, new_h, random_padding):
"""Get padding values based on image dimensions and padding strategy"""
delta_width = self.input_size[1] - new_w
delta_height = self.input_size[0] - new_h
img = resize(img, min(self.input_size))
img.thumbnail((self.input_size[1], self.input_size[0]))
delta_width = self.input_size[1] - img.width
delta_height = self.input_size[0] - img.height
if random_padding:
pad_width = np.random.randint(low=0, high=delta_width + 1)
pad_height = np.random.randint(low=0, high=delta_height + 1)
else:
pad_width = delta_width // 2
pad_height = delta_height // 2
padding = (
pad_width,
pad_height,
delta_width - pad_width,
delta_height - pad_height,
)
return ImageOps.expand(img, padding)
return pad_width, pad_height
......@@ -492,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else:
return [[x0, y0, x1, y1]]
# @measure_time
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = []
......
......@@ -2,7 +2,7 @@ weights:
layoutlmv3: Layout/LayoutLMv3/model_final.pth
doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small_2501
unimernet_small: MFR/unimernet_hf_small_2503
struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster
rapid_table: TabRec/RapidTable
\ No newline at end of file
......@@ -7,7 +7,8 @@ numpy>=1.21.6,<2.0.0
pydantic>=2.7.2
PyMuPDF>=1.24.9,<=1.24.14
scikit-learn>=1.0.2
torch>=2.2.2
transformers
torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
torchvision
transformers>=4.49.0
pdfminer.six==20231228
# The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
......@@ -36,25 +36,19 @@ if __name__ == '__main__':
"paddlepaddle==3.0.0b1;platform_system=='Linux'",
"paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
],
"full": ["unimernet==0.2.3", # unimernet升级0.2.3,移除torchtext/eva-decord的依赖
"torch>=2.2.2,<=2.3.1", # torch2.4.0及之后版本未测试,先卡住版本上限
"torchvision>=0.17.2,<=0.18.1", # torchvision 受torch版本约束
"full": [
"matplotlib<=3.9.0;platform_system=='Windows'", # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
"matplotlib;platform_system=='Linux' or platform_system=='Darwin'", # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
"ultralytics>=8.3.48", # yolov8,公式检测
"paddleocr==2.7.3", # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
"paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'", # 解决linux的段异常问题
"paddlepaddle==2.6.1;platform_system=='Windows'", # windows版本3.0.0效率下降,需锁定2.6.1
"struct-eqtable==0.3.2", # 表格解析
"einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2b1", # doclayout_yolo
"rapidocr-paddle>=1.4.5,<2.0.0", # rapidocr-paddle
"rapidocr_onnxruntime>=1.4.4,<2.0.0",
"rapid_table>=1.0.3,<2.0.0", # rapid_table
"PyYAML", # yaml
"openai", # openai SDK
"detectron2"
],
"old_linux":[
"albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment