Commit af27c0cc authored by myhloli
Browse files

refactor(magic_pdf): support mps device and optimize image processing

- Add support for Apple M1 chips (mps device)
- Refactor image processing for better performance and compatibility
- Update model loading and inference for various devices
- Adjust batch processing and memory management
parent 31ebceb5
...@@ -256,27 +256,28 @@ def may_batch_image_analyze( ...@@ -256,27 +256,28 @@ def may_batch_image_analyze(
batch_ratio = 1 batch_ratio = 1
device = get_device() device = get_device()
npu_support = False
if str(device).startswith('npu'): if str(device).startswith('npu'):
import torch_npu import torch_npu
if torch_npu.npu.is_available(): if torch_npu.npu.is_available():
npu_support = True
torch.npu.set_compile_mode(jit_compile=False) torch.npu.set_compile_mode(jit_compile=False)
if torch.cuda.is_available() and device != 'cpu' or npu_support: if str(device).startswith('npu') or str(device).startswith('cuda'):
gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device)))) gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
if gpu_memory is not None and gpu_memory >= 8: if gpu_memory is not None:
if gpu_memory >= 20: if gpu_memory >= 20:
batch_ratio = 16 batch_ratio = 16
elif gpu_memory >= 15: elif gpu_memory >= 15:
batch_ratio = 8 batch_ratio = 8
elif gpu_memory >= 10: elif gpu_memory >= 10:
batch_ratio = 4 batch_ratio = 4
else: elif gpu_memory >= 7:
batch_ratio = 2 batch_ratio = 2
else:
batch_ratio = 1
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}') logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_analyze = True batch_analyze = True
elif str(device).startswith('mps'):
batch_analyze = True
doc_analyze_start = time.time() doc_analyze_start = time.time()
if batch_analyze: if batch_analyze:
......
...@@ -118,7 +118,7 @@ class CustomPEKModel: ...@@ -118,7 +118,7 @@ class CustomPEKModel:
atom_model_name=AtomicModel.MFR, atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir, mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path, mfr_cfg_path=mfr_cfg_path,
device='cpu' if str(self.device).startswith("mps") else self.device, device=self.device,
) )
# 初始化layout模型 # 初始化layout模型
......
...@@ -44,7 +44,6 @@ def split_images(image, result_images=None): ...@@ -44,7 +44,6 @@ def split_images(image, result_images=None):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作 # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if x + new_long_side > width: if x + new_long_side > width:
continue continue
box = (x, 0, x + new_long_side, height)
sub_image = image[0:height, x:x + new_long_side] sub_image = image[0:height, x:x + new_long_side]
sub_images.append(sub_image) sub_images.append(sub_image)
else: # 如果高度是较长边 else: # 如果高度是较长边
...@@ -52,7 +51,6 @@ def split_images(image, result_images=None): ...@@ -52,7 +51,6 @@ def split_images(image, result_images=None):
# 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作 # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
if y + new_long_side > height: if y + new_long_side > height:
continue continue
box = (0, y, width, y + new_long_side)
sub_image = image[y:y + new_long_side, 0:width] sub_image = image[y:y + new_long_side, 0:width]
sub_images.append(sub_image) sub_images.append(sub_image)
......
...@@ -4,6 +4,8 @@ from doclayout_yolo import YOLOv10 ...@@ -4,6 +4,8 @@ from doclayout_yolo import YOLOv10
class DocLayoutYOLOModel(object): class DocLayoutYOLOModel(object):
def __init__(self, weight, device): def __init__(self, weight, device):
self.model = YOLOv10(weight) self.model = YOLOv10(weight)
if not device.startswith("cpu"):
self.model.half()
self.device = device self.device = device
def predict(self, image): def predict(self, image):
......
...@@ -4,6 +4,8 @@ from ultralytics import YOLO ...@@ -4,6 +4,8 @@ from ultralytics import YOLO
class YOLOv8MFDModel(object): class YOLOv8MFDModel(object):
def __init__(self, weight, device="cpu"): def __init__(self, weight, device="cpu"):
self.mfd_model = YOLO(weight) self.mfd_model = YOLO(weight)
if not device.startswith("cpu"):
self.mfd_model.half()
self.device = device self.device = device
def predict(self, image): def predict(self, image):
......
import argparse
import os
import re
import torch import torch
import unimernet.tasks as tasks
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor
class MathDataset(Dataset): class MathDataset(Dataset):
...@@ -18,46 +10,26 @@ class MathDataset(Dataset): ...@@ -18,46 +10,26 @@ class MathDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.image_paths) return len(self.image_paths)
def __getitem__(self, idx):
def latex_rm_whitespace(s: str): raw_image = self.image_paths[idx]
"""Remove unnecessary whitespace from LaTeX code.""" if self.transform:
text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})" image = self.transform(raw_image)
letter = "[a-zA-Z]" return image
noletter = "[\W_^\d]"
names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
news = s
while True:
s = news
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
if news == s:
break
return s
class UnimernetModel(object): class UnimernetModel(object):
def __init__(self, weight_dir, cfg_path, _device_="cpu"): def __init__(self, weight_dir, cfg_path, _device_="cpu"):
args = argparse.Namespace(cfg_path=cfg_path, options=None) from .unimernet_hf import UnimernetModel
cfg = Config(args) if _device_.startswith("mps"):
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth") self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
cfg.config.model.model_config.model_name = weight_dir else:
cfg.config.model.tokenizer_config.path = weight_dir self.model = UnimernetModel.from_pretrained(weight_dir)
task = tasks.setup_task(cfg)
self.model = task.build_model(cfg)
self.device = _device_ self.device = _device_
self.model.to(_device_) self.model.to(_device_)
if not _device_.startswith("cpu"):
self.model = self.model.to(dtype=torch.float16)
self.model.eval() self.model.eval()
vis_processor = load_processor(
"formula_image_eval",
cfg.config.datasets.formula_rec_eval.vis_processor.eval,
)
self.mfr_transform = transforms.Compose(
[
vis_processor,
]
)
def predict(self, mfd_res, image): def predict(self, mfd_res, image):
formula_list = [] formula_list = []
...@@ -76,16 +48,17 @@ class UnimernetModel(object): ...@@ -76,16 +48,17 @@ class UnimernetModel(object):
bbox_img = image[ymin:ymax, xmin:xmax] bbox_img = image[ymin:ymax, xmin:xmax]
mf_image_list.append(bbox_img) mf_image_list.append(bbox_img)
dataset = MathDataset(mf_image_list, transform=self.mfr_transform) dataset = MathDataset(mf_image_list, transform=self.model.transform)
dataloader = DataLoader(dataset, batch_size=32, num_workers=0) dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
mfr_res = [] mfr_res = []
for mf_img in dataloader: for mf_img in dataloader:
mf_img = mf_img.to(dtype=self.model.dtype)
mf_img = mf_img.to(self.device) mf_img = mf_img.to(self.device)
with torch.no_grad(): with torch.no_grad():
output = self.model.generate({"image": mf_img}) output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"]) mfr_res.extend(output["fixed_str"])
for res, latex in zip(formula_list, mfr_res): for res, latex in zip(formula_list, mfr_res):
res["latex"] = latex_rm_whitespace(latex) res["latex"] = latex
return formula_list return formula_list
def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list: def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
...@@ -130,22 +103,23 @@ class UnimernetModel(object): ...@@ -130,22 +103,23 @@ class UnimernetModel(object):
index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)} index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
# Create dataset with sorted images # Create dataset with sorted images
dataset = MathDataset(sorted_images, transform=self.mfr_transform) dataset = MathDataset(sorted_images, transform=self.model.transform)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
# Process batches and store results # Process batches and store results
mfr_res = [] mfr_res = []
for mf_img in dataloader: for mf_img in dataloader:
mf_img = mf_img.to(dtype=self.model.dtype)
mf_img = mf_img.to(self.device) mf_img = mf_img.to(self.device)
with torch.no_grad(): with torch.no_grad():
output = self.model.generate({"image": mf_img}) output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"]) mfr_res.extend(output["fixed_str"])
# Restore original order # Restore original order
unsorted_results = [""] * len(mfr_res) unsorted_results = [""] * len(mfr_res)
for new_idx, latex in enumerate(mfr_res): for new_idx, latex in enumerate(mfr_res):
original_idx = index_mapping[new_idx] original_idx = index_mapping[new_idx]
unsorted_results[original_idx] = latex_rm_whitespace(latex) unsorted_results[original_idx] = latex
# Fill results back # Fill results back
for res, latex in zip(backfill_list, unsorted_results): for res, latex in zip(backfill_list, unsorted_results):
......
from transformers.image_processing_utils import BaseImageProcessor from transformers.image_processing_utils import BaseImageProcessor
from PIL import Image, ImageOps
import numpy as np import numpy as np
import cv2 import cv2
import albumentations as alb import albumentations as alb
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
# TODO: dereference cv2 if possible # TODO: dereference cv2 if possible
class UnimerSwinImageProcessor(BaseImageProcessor): class UnimerSwinImageProcessor(BaseImageProcessor):
def __init__( def __init__(
self, self,
image_size = [192, 672], image_size = (192, 672),
): ):
self.input_size = [int(_) for _ in image_size] self.input_size = [int(_) for _ in image_size]
assert len(self.input_size) == 2 assert len(self.input_size) == 2
...@@ -27,56 +25,90 @@ class UnimerSwinImageProcessor(BaseImageProcessor): ...@@ -27,56 +25,90 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
def __call__(self, item): def __call__(self, item):
image = self.prepare_input(item) image = self.prepare_input(item)
return self.transform(image=np.array(image))['image'][:1] return self.transform(image=image)['image'][:1]
@staticmethod @staticmethod
def crop_margin(img: Image.Image) -> Image.Image: def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
data = np.array(img.convert("L")) """Crop margins of image using NumPy operations"""
data = data.astype(np.uint8) # Convert to grayscale if it's a color image
max_val = data.max() if len(img.shape) == 3 and img.shape[2] == 3:
min_val = data.min() gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
if max_val == min_val: else:
gray = img.copy()
# Normalize and threshold
if gray.max() == gray.min():
return img return img
data = (data - min_val) / (max_val - min_val) * 255
gray = 255 * (data < 200).astype(np.uint8)
coords = cv2.findNonZero(gray) # Find all non-zero points (text) normalized = (((gray - gray.min()) / (gray.max() - gray.min())) * 255).astype(np.uint8)
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box binary = 255 * (normalized < 200).astype(np.uint8)
return img.crop((a, b, w + a, h + b))
# Find bounding box
coords = cv2.findNonZero(binary) # Find all non-zero points (text)
x, y, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
def prepare_input(self, img: Image.Image, random_padding: bool = False): # Return cropped image
return img[y:y + h, x:x + w]
def prepare_input(self, img, random_padding: bool = False):
""" """
Convert PIL Image to tensor according to specified input_size after following steps below: Convert PIL Image or numpy array to properly sized and padded image after:
- resize - crop margins
- rotate (if align_long_axis is True and image is not aligned longer axis with canvas) - resize while maintaining aspect ratio
- pad - pad to target size
""" """
if img is None: if img is None:
return return None
# crop margins
try: try:
img = self.crop_margin(img.convert("RGB")) img = self.crop_margin_numpy(img)
except OSError: except Exception:
# might throw an error for broken files # might throw an error for broken files
return return None
if img.shape[0] == 0 or img.shape[1] == 0:
return None
# Resize while preserving aspect ratio
h, w = img.shape[:2]
scale = min(self.input_size[0] / h, self.input_size[1] / w)
new_h, new_w = int(h * scale), int(w * scale)
resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
# Calculate padding
pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
# Create and apply padding
channels = 3 if len(img.shape) == 3 else 1
padded_img = np.full((self.input_size[0], self.input_size[1], channels), 255, dtype=np.uint8)
padded_img[pad_height:pad_height + new_h, pad_width:pad_width + new_w] = resized_img
return padded_img
def _calculate_padding(self, new_w, new_h, random_padding):
"""Calculate padding values for PIL images"""
delta_width = self.input_size[1] - new_w
delta_height = self.input_size[0] - new_h
pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
if img.height == 0 or img.width == 0: return (
return pad_width,
pad_height,
delta_width - pad_width,
delta_height - pad_height,
)
def _get_padding_values(self, new_w, new_h, random_padding):
"""Get padding values based on image dimensions and padding strategy"""
delta_width = self.input_size[1] - new_w
delta_height = self.input_size[0] - new_h
img = resize(img, min(self.input_size))
img.thumbnail((self.input_size[1], self.input_size[0]))
delta_width = self.input_size[1] - img.width
delta_height = self.input_size[0] - img.height
if random_padding: if random_padding:
pad_width = np.random.randint(low=0, high=delta_width + 1) pad_width = np.random.randint(low=0, high=delta_width + 1)
pad_height = np.random.randint(low=0, high=delta_height + 1) pad_height = np.random.randint(low=0, high=delta_height + 1)
else: else:
pad_width = delta_width // 2 pad_width = delta_width // 2
pad_height = delta_height // 2 pad_height = delta_height // 2
padding = (
pad_width, return pad_width, pad_height
pad_height,
delta_width - pad_width,
delta_height - pad_height,
)
return ImageOps.expand(img, padding)
...@@ -492,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -492,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else: else:
return [[x0, y0, x1, y1]] return [[x0, y0, x1, y1]]
# @measure_time
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = [] page_line_list = []
......
...@@ -2,7 +2,7 @@ weights: ...@@ -2,7 +2,7 @@ weights:
layoutlmv3: Layout/LayoutLMv3/model_final.pth layoutlmv3: Layout/LayoutLMv3/model_final.pth
doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small_2501 unimernet_small: MFR/unimernet_hf_small_2503
struct_eqtable: TabRec/StructEqTable struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster tablemaster: TabRec/TableMaster
rapid_table: TabRec/RapidTable rapid_table: TabRec/RapidTable
\ No newline at end of file
...@@ -7,7 +7,8 @@ numpy>=1.21.6,<2.0.0 ...@@ -7,7 +7,8 @@ numpy>=1.21.6,<2.0.0
pydantic>=2.7.2 pydantic>=2.7.2
PyMuPDF>=1.24.9,<=1.24.14 PyMuPDF>=1.24.9,<=1.24.14
scikit-learn>=1.0.2 scikit-learn>=1.0.2
torch>=2.2.2 torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
transformers torchvision
transformers>=4.49.0
pdfminer.six==20231228 pdfminer.six==20231228
# The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator. # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
...@@ -36,25 +36,19 @@ if __name__ == '__main__': ...@@ -36,25 +36,19 @@ if __name__ == '__main__':
"paddlepaddle==3.0.0b1;platform_system=='Linux'", "paddlepaddle==3.0.0b1;platform_system=='Linux'",
"paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'", "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
], ],
"full": ["unimernet==0.2.3", # unimernet升级0.2.3,移除torchtext/eva-decord的依赖 "full": [
"torch>=2.2.2,<=2.3.1", # torch2.4.0及之后版本未测试,先卡住版本上限
"torchvision>=0.17.2,<=0.18.1", # torchvision 受torch版本约束
"matplotlib<=3.9.0;platform_system=='Windows'", # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败 "matplotlib<=3.9.0;platform_system=='Windows'", # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
"matplotlib;platform_system=='Linux' or platform_system=='Darwin'", # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug "matplotlib;platform_system=='Linux' or platform_system=='Darwin'", # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
"ultralytics>=8.3.48", # yolov8,公式检测 "ultralytics>=8.3.48", # yolov8,公式检测
"paddleocr==2.7.3", # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3 "paddleocr==2.7.3", # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
"paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'", # 解决linux的段异常问题 "paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'", # 解决linux的段异常问题
"paddlepaddle==2.6.1;platform_system=='Windows'", # windows版本3.0.0效率下降,需锁定2.6.1 "paddlepaddle==2.6.1;platform_system=='Windows'", # windows版本3.0.0效率下降,需锁定2.6.1
"struct-eqtable==0.3.2", # 表格解析
"einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2b1", # doclayout_yolo "doclayout_yolo==0.0.2b1", # doclayout_yolo
"rapidocr-paddle>=1.4.5,<2.0.0", # rapidocr-paddle "rapidocr-paddle>=1.4.5,<2.0.0", # rapidocr-paddle
"rapidocr_onnxruntime>=1.4.4,<2.0.0", "rapidocr_onnxruntime>=1.4.4,<2.0.0",
"rapid_table>=1.0.3,<2.0.0", # rapid_table "rapid_table>=1.0.3,<2.0.0", # rapid_table
"PyYAML", # yaml "PyYAML", # yaml
"openai", # openai SDK "openai", # openai SDK
"detectron2"
], ],
"old_linux":[ "old_linux":[
"albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统 "albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment