Unverified Commit 4bb54393 authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #1427 from opendatalab/release-1.0.0

Release 1.0.0
parents 04f084ac 1c9f9942
@@ -8,14 +8,51 @@ class DocLayoutYOLOModel(object):
    def predict(self, image):
        layout_res = []
        doclayout_yolo_res = self.model.predict(
            image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
        )[0]
        for xyxy, conf, cla in zip(
            doclayout_yolo_res.boxes.xyxy.cpu(),
            doclayout_yolo_res.boxes.conf.cpu(),
            doclayout_yolo_res.boxes.cls.cpu(),
        ):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                "category_id": int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 3),
            }
            layout_res.append(new_item)
        return layout_res

    def batch_predict(self, images: list, batch_size: int) -> list:
        images_layout_res = []
        for index in range(0, len(images), batch_size):
            doclayout_yolo_res = [
                image_res.cpu()
                for image_res in self.model.predict(
                    images[index : index + batch_size],
                    imgsz=1024,
                    conf=0.25,
                    iou=0.45,
                    verbose=False,
                    device=self.device,
                )
            ]
            for image_res in doclayout_yolo_res:
                layout_res = []
                for xyxy, conf, cla in zip(
                    image_res.boxes.xyxy,
                    image_res.boxes.conf,
                    image_res.boxes.cls,
                ):
                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                    new_item = {
                        "category_id": int(cla.item()),
                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                        "score": round(float(conf.item()), 3),
                    }
                    layout_res.append(new_item)
                images_layout_res.append(layout_res)
        return images_layout_res
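For reference, a minimal usage sketch of the new batch interface; the weight path and page images below are placeholders, and the constructor is called positionally as in doclayout_yolo_model_init further down:

# Illustrative sketch only; the weight path and image files are placeholders.
import cv2
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel

model = DocLayoutYOLOModel("models/doclayout_yolo.pt", "cuda")      # hypothetical weight path
pages = [cv2.imread(f"page_{i}.png") for i in range(4)]             # hypothetical page images
for page_idx, layout_res in enumerate(model.batch_predict(pages, batch_size=2)):
    # each item carries a category_id, an 8-value poly and a confidence score
    print(page_idx, len(layout_res))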
@@ -2,11 +2,30 @@ from ultralytics import YOLO
class YOLOv8MFDModel(object):
    def __init__(self, weight, device="cpu"):
        self.mfd_model = YOLO(weight)
        self.device = device

    def predict(self, image):
        mfd_res = self.mfd_model.predict(
            image, imgsz=1888, conf=0.25, iou=0.45, verbose=False, device=self.device
        )[0]
        return mfd_res

    def batch_predict(self, images: list, batch_size: int) -> list:
        images_mfd_res = []
        for index in range(0, len(images), batch_size):
            mfd_res = [
                image_res.cpu()
                for image_res in self.mfd_model.predict(
                    images[index : index + batch_size],
                    imgsz=1888,
                    conf=0.25,
                    iou=0.45,
                    verbose=False,
                    device=self.device,
                )
            ]
            for image_res in mfd_res:
                images_mfd_res.append(image_res)
        return images_mfd_res
import argparse
import os
import re

import torch
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
from unimernet.processors import load_processor
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
    letter = "[a-zA-Z]"
    noletter = "[\W_^\d]"
    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
        if news == s:
            break
    return s


class UnimernetModel(object):
    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
        args = argparse.Namespace(cfg_path=cfg_path, options=None)
        cfg = Config(args)
        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
@@ -62,20 +60,28 @@ class UnimernetModel(object):
        self.device = _device_
        self.model.to(_device_)
        self.model.eval()
        vis_processor = load_processor(
            "formula_image_eval",
            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
        )
        self.mfr_transform = transforms.Compose(
            [
                vis_processor,
            ]
        )

    def predict(self, mfd_res, image):
        formula_list = []
        mf_image_list = []
        for xyxy, conf, cla in zip(
            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
        ):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                "category_id": 13 + int(cla.item()),
                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                "score": round(float(conf.item()), 2),
                "latex": "",
            }
            formula_list.append(new_item)
        pil_img = Image.fromarray(image)
@@ -88,11 +94,48 @@ class UnimernetModel(object):
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            mfr_res.extend(output["pred_str"])
        for res, latex in zip(formula_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return formula_list
    def batch_predict(
        self, images_mfd_res: list, images: list, batch_size: int = 64
    ) -> list:
        images_formula_list = []
        mf_image_list = []
        backfill_list = []
        for image_index in range(len(images_mfd_res)):
            mfd_res = images_mfd_res[image_index]
            pil_img = Image.fromarray(images[image_index])
            formula_list = []

            for xyxy, conf, cla in zip(
                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
            ):
                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
                new_item = {
                    "category_id": 13 + int(cla.item()),
                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                    "score": round(float(conf.item()), 2),
                    "latex": "",
                }
                formula_list.append(new_item)
                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
                mf_image_list.append(bbox_img)

            images_formula_list.append(formula_list)
            backfill_list += formula_list

        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
        mfr_res = []
        for mf_img in dataloader:
            mf_img = mf_img.to(self.device)
            with torch.no_grad():
                output = self.model.generate({"image": mf_img})
            mfr_res.extend(output["pred_str"])
        for res, latex in zip(backfill_list, mfr_res):
            res["latex"] = latex_rm_whitespace(latex)
        return images_formula_list
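For context, a hedged sketch of how the two new batch interfaces are meant to compose, with MFD detections feeding batched formula recognition; the paths are placeholders and `images` is assumed to be a list of page images as numpy arrays:

# Illustrative sketch only; paths are placeholders, `images` is a list of page images (numpy arrays).
mfd_model = YOLOv8MFDModel("models/yolo_v8_mfd.pt", device="cuda")
mfr_model = UnimernetModel("models/unimernet", "configs/unimernet.yaml", _device_="cuda")

images_mfd_res = mfd_model.batch_predict(images, batch_size=4)          # one YOLO result per page
images_formula_list = mfr_model.batch_predict(images_mfd_res, images)   # per-page formula dicts with "latex" backfilled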
import torch
from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel


def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel(ocr_engine)
    else:
        logger.error('table model type not allow')
        exit(1)
@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
def mfd_model_init(weight, device='cpu'):
    if str(device).startswith("npu"):
        device = torch.device(device)
    mfd_model = YOLOv8MFDModel(weight, device)
    return mfd_model
@@ -53,16 +57,26 @@ def layout_model_init(weight, config_file, device):
def doclayout_yolo_model_init(weight, device='cpu'):
    if str(device).startswith("npu"):
        device = torch.device(device)
    model = DocLayoutYOLOModel(weight, device)
    return model


def langdetect_model_init(langdetect_model_weight, device='cpu'):
    if str(device).startswith("npu"):
        device = torch.device(device)
    model = YOLOv11LangDetModel(langdetect_model_weight, device)
    return model


def ocr_model_init(show_log: bool = False,
                   det_db_box_thresh=0.3,
                   lang=None,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
    if lang is not None and lang != '':
        model = ModifiedPaddleOCR(
            show_log=show_log,
@@ -77,7 +91,6 @@ def ocr_model_init(show_log: bool = False,
            det_db_box_thresh=det_db_box_thresh,
            use_dilation=use_dilation,
            det_db_unclip_ratio=det_db_unclip_ratio,
        )
    return model
@@ -124,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
        else:
            logger.error('layout model name not allow')
            exit(1)
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
@@ -146,8 +162,18 @@ def atom_model_init(model_name: str, **kwargs):
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device'),
            kwargs.get('ocr_engine')
        )
    elif model_name == AtomicModel.LangDetect:
        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
            atom_model = langdetect_model_init(
                kwargs.get('langdetect_model_weight'),
                kwargs.get('device')
            )
        else:
            logger.error('langdetect model name not allow')
            exit(1)
    else:
        logger.error('model name not allow')
        exit(1)
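As a reference for the new branch, a hedged sketch of how the language-detection atomic model might be requested through atom_model_init; the kwarg names mirror the kwargs.get calls above and the weight path is a placeholder:

# Illustrative sketch only; the weight path is hypothetical.
langdetect_model = atom_model_init(
    model_name=AtomicModel.LangDetect,
    langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
    langdetect_model_weight='models/yolo_v11_langdetect.pt',
    device='cpu',
)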
@@ -45,7 +45,7 @@ def clean_vram(device, vram_threshold=8):
    total_memory = get_vram(device)
    if total_memory and total_memory <= vram_threshold:
        gc_start = time.time()
        clean_memory(device)
        gc_time = round(time.time() - gc_start, 2)
        logger.info(f"gc time: {gc_time}")

@@ -54,4 +54,10 @@ def get_vram(device):
    if torch.cuda.is_available() and device != 'cpu':
        total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
        return total_memory
    elif str(device).startswith("npu"):
        import torch_npu
        if torch_npu.npu.is_available():
            total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
            return total_memory
    else:
        return None
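A small sketch of the call site behaviour; the device strings are examples only:

# Illustrative: clean_vram only calls clean_memory(device) when total VRAM is at or below the threshold (GB).
clean_vram('cuda:0', vram_threshold=8)
# NPU devices are now routed through torch_npu when available.
clean_vram('npu:0', vram_threshold=8)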
@@ -303,4 +303,54 @@ def calculate_is_angle(poly):
        return False
    else:
        # logger.info((p3[1] - p1[1])/height)
        return True


class ONNXModelSingleton:
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_onnx_model(self, **kwargs):
        lang = kwargs.get('lang', None)
        det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
        use_dilation = kwargs.get('use_dilation', True)
        det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
        key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
        if key not in self._models:
            self._models[key] = onnx_model_init(key)
        return self._models[key]


def onnx_model_init(key):
    import importlib.resources
    resource_path = importlib.resources.path('rapidocr_onnxruntime.models', '')
    onnx_model = None
    additional_ocr_params = {
        "use_onnx": True,
        "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
        "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
        "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
        "det_db_box_thresh": key[1],
        "use_dilation": key[2],
        "det_db_unclip_ratio": key[3],
    }
    # logger.info(f"additional_ocr_params: {additional_ocr_params}")
    if key[0] is not None:
        additional_ocr_params["lang"] = key[0]

    from paddleocr import PaddleOCR
    onnx_model = PaddleOCR(**additional_ocr_params)

    if onnx_model is None:
        logger.error('model init failed')
        exit(1)
    else:
        return onnx_model
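A usage sketch of the singleton: repeated calls with the same detection parameters reuse one cached ONNX PaddleOCR instance (the parameter values below are simply the defaults shown above):

# Illustrative sketch; models are cached per (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio) key.
manager = ONNXModelSingleton()
ocr_a = manager.get_onnx_model(lang='ch', det_db_box_thresh=0.3)
ocr_b = manager.get_onnx_model(lang='ch', det_db_box_thresh=0.3)
assert ocr_a is ocr_b  # same cached instance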
import copy
import platform
import time

import cv2
import numpy as np
import torch
from paddleocr import PaddleOCR
from ppocr.utils.logging import get_logger
@@ -9,12 +11,25 @@ from ppocr.utils.utility import alpha_to_color, binarize_img
from tools.infer.predict_system import sorted_boxes
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop

from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
    ONNXModelSingleton

logger = get_logger()


class ModifiedPaddleOCR(PaddleOCR):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lang = kwargs.get('lang', 'ch')
        # fall back to ONNX models when the CPU architecture is ARM and CUDA is not available
        if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
            self.use_onnx = True
            onnx_model_manager = ONNXModelSingleton()
            self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
        else:
            self.use_onnx = False

    def ocr(self,
            img,
            det=True,
@@ -79,7 +94,10 @@ class ModifiedPaddleOCR(PaddleOCR):
            ocr_res = []
            for img in imgs:
                img = preprocess_image(img)
                if self.lang in ['ch'] and self.use_onnx:
                    dt_boxes, elapse = self.additional_ocr.text_detector(img)
                else:
                    dt_boxes, elapse = self.text_detector(img)
                if dt_boxes is None:
                    ocr_res.append(None)
                    continue
@@ -106,7 +124,10 @@ class ModifiedPaddleOCR(PaddleOCR):
                    img, cls_res_tmp, elapse = self.text_classifier(img)
                    if not rec:
                        cls_res.append(cls_res_tmp)
                if self.lang in ['ch'] and self.use_onnx:
                    rec_res, elapse = self.additional_ocr.text_recognizer(img)
                else:
                    rec_res, elapse = self.text_recognizer(img)
                ocr_res.append(rec_res)
            if not rec:
                return cls_res
@@ -121,7 +142,10 @@ class ModifiedPaddleOCR(PaddleOCR):
        start = time.time()
        ori_im = img.copy()
        if self.lang in ['ch'] and self.use_onnx:
            dt_boxes, elapse = self.additional_ocr.text_detector(img)
        else:
            dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse

        if dt_boxes is None:
@@ -159,8 +183,10 @@ class ModifiedPaddleOCR(PaddleOCR):
            time_dict['cls'] = elapse
            logger.debug("cls num  : {}, elapsed : {}".format(
                len(img_crop_list), elapse))
        if self.lang in ['ch'] and self.use_onnx:
            rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
        else:
            rec_res, elapse = self.text_recognizer(img_crop_list)
        time_dict['rec'] = elapse
        logger.debug("rec_res num  : {}, elapsed : {}".format(
            len(rec_res), elapse))
import cv2
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable


class RapidTableModel(object):
    def __init__(self, ocr_engine):
        self.table_model = RapidTable()
        # if ocr_engine is None:
        #     self.ocr_model_name = "RapidOCR"
        #     if torch.cuda.is_available():
        #         from rapidocr_paddle import RapidOCR
        #         self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
        #     else:
        #         from rapidocr_onnxruntime import RapidOCR
        #         self.ocr_engine = RapidOCR()
        # else:
        #     self.ocr_model_name = "PaddleOCR"
        #     self.ocr_engine = ocr_engine

        self.ocr_model_name = "RapidOCR"
        if torch.cuda.is_available():
            from rapidocr_paddle import RapidOCR
            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
        else:
            from rapidocr_onnxruntime import RapidOCR
            self.ocr_engine = RapidOCR()

    def predict(self, image):
        if self.ocr_model_name == "RapidOCR":
            ocr_result, _ = self.ocr_engine(np.asarray(image))
        elif self.ocr_model_name == "PaddleOCR":
            bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
            if ocr_result:
                ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
                              len(item) == 2 and isinstance(item[1], tuple)]
            else:
                ocr_result = None
        else:
            logger.error("OCR model not supported")
            ocr_result = None

        if ocr_result:
            html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
            return html_code, table_cell_bboxes, elapse
        else:
            return None, None, None
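A hedged sketch of the table model in isolation; `table_image` is assumed to be an RGB PIL image of a cropped table, and note that with the commented-out branch above the ocr_engine argument is accepted but RapidOCR is still always constructed internally:

# Illustrative sketch only; `table_image` is a placeholder for an RGB PIL.Image of a table crop.
table_model = RapidTableModel(ocr_engine=None)
html_code, table_cell_bboxes, elapse = table_model.predict(table_image)
if html_code is None:
    print("no OCR result, table skipped")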
from abc import ABC, abstractmethod
from typing import Callable
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.operators.pipes import PipeResult
class InferenceResultBase(ABC):
@abstractmethod
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
pass
@abstractmethod
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
pass
@abstractmethod
def dump_model(self, writer: DataWriter, file_path: str):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
pass
@abstractmethod
def get_infer_res(self):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
def pipe_txt_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_ocr_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
pass
@@ -7,13 +7,11 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.libs.version import __version__
from magic_pdf.operators.pipes import PipeResult
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.operators import InferenceResultBase


class InferenceResult(InferenceResultBase):
    def __init__(self, inference_results: list, dataset: Dataset):
@@ -71,40 +69,6 @@ class InferenceResult(InferenceResultBase):
        """
        return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
def pipe_auto_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pdf_proc_method = classify(self._dataset.data_bits())
if pdf_proc_method == SupportedPdfParseMethod.TXT:
return self.pipe_txt_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
else:
return self.pipe_ocr_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
    def pipe_txt_mode(
        self,
        imageWriter: DataWriter,
import copy
import json
import os
from typing import Callable

from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
@@ -23,12 +23,34 @@ class PipeResult:
        self._pipe_res = pipe_res
        self._dataset = dataset

    def get_markdown(
        self,
        img_dir_or_bucket_prefix: str,
        drop_mode=DropMode.NONE,
        md_make_mode=MakeMode.MM_MD,
    ) -> str:
        """Get markdown content.

        Args:
            img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
            md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.

        Returns:
            str: return markdown content
        """
        pdf_info_list = self._pipe_res['pdf_info']
        md_content = union_make(
            pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
        )
        return md_content

    def dump_md(
        self,
        writer: DataWriter,
        file_path: str,
        img_dir_or_bucket_prefix: str,
        drop_mode=DropMode.NONE,
        md_make_mode=MakeMode.MM_MD,
    ):
        """Dump The Markdown.
@@ -37,36 +59,68 @@ class PipeResult:
            writer (DataWriter): File writer handle
            file_path (str): The file location of markdown
            img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
            md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
        """
        md_content = self.get_markdown(
            img_dir_or_bucket_prefix, drop_mode=drop_mode, md_make_mode=md_make_mode
        )
        writer.write_string(file_path, md_content)

    def get_content_list(
        self,
        image_dir_or_bucket_prefix: str,
        drop_mode=DropMode.NONE,
    ) -> str:
        """Get Content List.

        Args:
            image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.

        Returns:
            str: content list content
        """
        pdf_info_list = self._pipe_res['pdf_info']
        content_list = union_make(
            pdf_info_list,
            MakeMode.STANDARD_FORMAT,
            drop_mode,
            image_dir_or_bucket_prefix,
        )
        return content_list

    def dump_content_list(
        self,
        writer: DataWriter,
        file_path: str,
        image_dir_or_bucket_prefix: str,
        drop_mode=DropMode.NONE,
    ):
        """Dump Content List.

        Args:
            writer (DataWriter): File writer handle
            file_path (str): The file location of content list
            image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
            drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
        """
        content_list = self.get_content_list(
            image_dir_or_bucket_prefix, drop_mode=drop_mode,
        )
        writer.write_string(
            file_path, json.dumps(content_list, ensure_ascii=False, indent=4)
        )

    def get_middle_json(self) -> str:
        """Get middle json.

        Returns:
            str: The content of middle json
        """
        return json.dumps(self._pipe_res, ensure_ascii=False, indent=4)

    def dump_middle_json(self, writer: DataWriter, file_path: str):
        """Dump the result of pipeline.
@@ -74,9 +128,8 @@ class PipeResult:
            writer (DataWriter): File writer handler
            file_path (str): The file location of middle json
        """
        middle_json = self.get_middle_json()
        writer.write_string(file_path, middle_json)

    def draw_layout(self, file_path: str) -> None:
        """Draw the layout.
@@ -123,7 +176,7 @@ class PipeResult:
        Returns:
            str: compress the pipeline result and return
        """
        return JsonCompressor.compress_json(self._pipe_res)

    def apply(self, proc: Callable, *args, **kwargs):
        """Apply callable method which.
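A sketch of the new getter/dumper split on PipeResult; `pipe_result` and `md_writer` are assumed to be a PipeResult and a DataWriter created upstream, and the file names are placeholders:

# Illustrative sketch only; `pipe_result` is a PipeResult and `md_writer` a DataWriter built elsewhere.
md_content = pipe_result.get_markdown("images", drop_mode=DropMode.NONE, md_make_mode=MakeMode.MM_MD)
content_list = pipe_result.get_content_list("images")
middle_json = pipe_result.get_middle_json()

pipe_result.dump_md(md_writer, "example.md", "images")                            # same content as get_markdown
pipe_result.dump_content_list(md_writer, "example_content_list.json", "images")
pipe_result.dump_middle_json(md_writer, "example_middle.json")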
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_ocr(dataset: Dataset,
model_list,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
return pdf_parse_union(model_list,
dataset,
imageWriter,
SupportedPdfParseMethod.OCR,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_txt(
dataset: Dataset,
model_list,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
return pdf_parse_union(model_list,
dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
import copy
import os
import re
import statistics
import time
from typing import List
@@ -13,11 +14,12 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title

try:
    import torchtext
@@ -28,15 +30,15 @@ except ImportError:
    pass

from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.post_proc.para_split_v3 import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \
    remove_overlaps_min_spans, check_chars_is_overlap_in_span

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update check
os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger


def __replace_STX_ETX(text_str: str):
@@ -63,11 +65,22 @@ def __replace_0xfffd(text_str: str):
        return s
    return text_str


# split ligature characters into their component letters
def __replace_ligatures(text: str):
    ligatures = {
        'ﬁ': 'fi', 'ﬂ': 'fl', 'ﬀ': 'ff', 'ﬃ': 'ffi', 'ﬄ': 'ffl', 'ﬅ': 'ft', 'ﬆ': 'st'
    }
    return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)


def chars_to_content(span):
    # check whether the span's char list is empty
    if len(span['chars']) == 0:
        pass
        # span['content'] = ''
    elif check_chars_is_overlap_in_span(span['chars']):
        pass
    else:
        # first sort the chars by the x coordinate of each char's bbox center
        span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
@@ -78,11 +91,16 @@ def chars_to_content(span):
        content = ''
        for char in span['chars']:
            # if the gap between the next char's x0 and this char's x1 exceeds 0.25 of the average char width, insert a space between them
            char1 = char
            char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
            if char2 and char2['bbox'][0] - char1['bbox'][2] > char_avg_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
                content += f"{char['c']} "
            else:
                content += char['c']

        content = __replace_ligatures(content)
        span['content'] = __replace_0xfffd(content)
        del span['chars']
@@ -98,6 +116,10 @@ def fill_char_in_spans(spans, all_chars):
    spans = sorted(spans, key=lambda x: x['bbox'][1])

    for char in all_chars:
        # skip chars with a degenerate (invalid) bbox
        x1, y1, x2, y2 = char['bbox']
        if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
            continue
        for span in spans:
            if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
                span['chars'].append(char)
@@ -152,14 +174,16 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
    # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
    # cids are rendered as 0xfffd; ligatures are kept intact here
    text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
    all_pymu_chars = []
    for block in text_blocks_raw:
        for line in block['lines']:
            cosine, sine = line['dir']
            if abs(cosine) < 0.9 or abs(sine) > 0.1:
                continue
            for span in line['spans']:
                all_pymu_chars.extend(span['chars'])
@@ -255,19 +279,23 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
    return spans
def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str):
    from transformers import LayoutLMv3ForTokenClassification
    device = get_device()
    if torch.cuda.is_available():
        device = torch.device('cuda')
        if torch.cuda.is_bf16_supported():
            supports_bfloat16 = True
        else:
            supports_bfloat16 = False
    elif str(device).startswith("npu"):
        import torch_npu
        if torch_npu.npu.is_available():
            device = torch.device('npu')
            supports_bfloat16 = False
        else:
            device = torch.device('cpu')
            supports_bfloat16 = False
    else:
        device = torch.device('cpu')
        supports_bfloat16 = False
@@ -345,6 +373,8 @@ def cal_block_index(fix_blocks, sorted_bboxes):
        # sort with xy-cut
        block_bboxes = []
        for block in fix_blocks:
            # if any value of block['bbox'] is negative, clamp it to 0
            block['bbox'] = [max(0, x) for x in block['bbox']]
            block_bboxes.append(block['bbox'])

        # drop the virtual line info in figure/table body blocks and backfill it with the real_lines info
@@ -738,6 +768,11 @@ def parse_page_core(
    """Reorder the blocks"""
    sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])

    """Reorder inside blocks (sort the multiple captions/footnotes within image and table blocks)"""
    for block in sorted_blocks:
        if block['type'] in [BlockType.Image, BlockType.Table]:
            block['blocks'] = sorted(block['blocks'], key=lambda b: b['index'])

    """Collect the lists that QA needs to externalize"""
    images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
@@ -819,13 +854,29 @@ def pdf_parse_union(
    """Paragraph splitting"""
    para_split(pdf_info_dict)

    """LLM-aided post-processing"""
    llm_aided_config = get_llm_aided_config()
    if llm_aided_config is not None:
        """Formula optimization"""
        formula_aided_config = llm_aided_config.get('formula_aided', None)
        if formula_aided_config is not None:
            llm_aided_formula(pdf_info_dict, formula_aided_config)
        """Text optimization"""
        text_aided_config = llm_aided_config.get('text_aided', None)
        if text_aided_config is not None:
            llm_aided_text(pdf_info_dict, text_aided_config)
        """Title optimization"""
        title_aided_config = llm_aided_config.get('title_aided', None)
        if title_aided_config is not None:
            llm_aided_title(pdf_info_dict, title_aided_config)

    """Convert dict to list"""
    pdf_info_list = dict_to_list(pdf_info_dict)
    new_pdf_info_dict = {
        'pdf_info': pdf_info_list,
    }

    clean_memory(get_device())

    return new_pdf_info_dict
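The LLM-aided step is driven entirely by get_llm_aided_config(); a hedged sketch of the config shape it dispatches on, where only the three top-level keys come from the code above and the inner fields are hypothetical placeholders:

# Illustrative only: top-level keys are from the code above; inner fields are hypothetical.
llm_aided_config = {
    'formula_aided': None,            # None skips llm_aided_formula
    'text_aided': None,               # None skips llm_aided_text
    'title_aided': {'enable': True},  # hypothetical inner field; a non-None dict triggers llm_aided_title
}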
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
def __init__(
self,
dataset: Dataset,
model_list: list,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
super().__init__(
dataset,
model_list,
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.infer_res = doc_analyze(
self.dataset,
ocr=True,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(
self.dataset,
self.infer_res,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('ocr_pipe mk content list finished')
return result
def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'ocr_pipe mk {md_make_mode} finished')
return result
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):
def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(dataset, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.model_list = doc_analyze(self.dataset, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.dataset, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('txt_pipe mk content list finished')
return result
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'txt_pipe mk {md_make_mode} finished')
return result
import json
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.commons import join_path
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
class UNIPipe(AbsPipe):
def __init__(
self,
dataset: Dataset,
jso_useful_key: dict,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
self.pdf_type = jso_useful_key['_pdf_type']
super().__init__(
dataset,
jso_useful_key['model_list'],
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
if len(self.model_list) == 0:
self.input_model_is_empty = True
else:
self.input_model_is_empty = False
def pipe_classify(self):
self.pdf_type = AbsPipe.classify(self.pdf_bytes)
def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(
self.dataset,
ocr=False,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(
self.dataset,
ocr=True,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self):
if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(
self.dataset,
self.model_list,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(
self.dataset,
self.model_list,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
)
def pipe_mk_uni_format(
self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON
):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('uni_pipe mk content list finished')
return result
def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'uni_pipe mk {md_make_mode} finished')
return result
if __name__ == '__main__':
# test
from magic_pdf.data.data_reader_writer import DataReader
drw = DataReader(r'D:/project/20231108code-clean')
pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r'linshixuqiu\19983-00.json'
pdf_bytes = drw.read(pdf_file_path)
model_json_txt = drw.read(model_file_path).decode()
model_list = json.loads(model_json_txt)
write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = 'imgs'
img_writer = DataWriter(join_path(write_path, img_bucket_path))
# pdf_type = UNIPipe.classify(pdf_bytes)
# jso_useful_key = {
# "_pdf_type": pdf_type,
# "model_list": model_list
# }
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
pipe.pipe_classify()
pipe.pipe_parse()
md_content = pipe.pipe_mk_markdown(img_bucket_path)
content_list = pipe.pipe_mk_uni_format(img_bucket_path)
md_writer = DataWriter(write_path)
md_writer.write_string('19983-00.md', md_content)
md_writer.write_string(
'19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
)
md_writer.write_string('19983-00.txt', str(content_list))
# Copyright (c) Opendatalab. All rights reserved.