Unverified Commit 4bb54393 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1427 from opendatalab/release-1.0.0

Release 1.0.0
parents 04f084ac 1c9f9942
......@@ -8,14 +8,51 @@ class DocLayoutYOLOModel(object):
def predict(self, image):
layout_res = []
doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
doclayout_yolo_res.boxes.cls.cpu()):
doclayout_yolo_res = self.model.predict(
image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
)[0]
for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(),
doclayout_yolo_res.boxes.conf.cpu(),
doclayout_yolo_res.boxes.cls.cpu(),
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 3),
"category_id": int(cla.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 3),
}
layout_res.append(new_item)
return layout_res
\ No newline at end of file
return layout_res
def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = []
for index in range(0, len(images), batch_size):
doclayout_yolo_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size],
imgsz=1024,
conf=0.25,
iou=0.45,
verbose=False,
device=self.device,
)
]
for image_res in doclayout_yolo_res:
layout_res = []
for xyxy, conf, cla in zip(
image_res.boxes.xyxy,
image_res.boxes.conf,
image_res.boxes.cls,
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
"category_id": int(cla.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 3),
}
layout_res.append(new_item)
images_layout_res.append(layout_res)
return images_layout_res
......@@ -2,11 +2,30 @@ from ultralytics import YOLO
class YOLOv8MFDModel(object):
def __init__(self, weight, device='cpu'):
def __init__(self, weight, device="cpu"):
self.mfd_model = YOLO(weight)
self.device = device
def predict(self, image):
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
mfd_res = self.mfd_model.predict(
image, imgsz=1888, conf=0.25, iou=0.45, verbose=False, device=self.device
)[0]
return mfd_res
def batch_predict(self, images: list, batch_size: int) -> list:
images_mfd_res = []
for index in range(0, len(images), batch_size):
mfd_res = [
image_res.cpu()
for image_res in self.mfd_model.predict(
images[index : index + batch_size],
imgsz=1888,
conf=0.25,
iou=0.45,
verbose=False,
device=self.device,
)
]
for image_res in mfd_res:
images_mfd_res.append(image_res)
return images_mfd_res
import os
import argparse
import os
import re
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import unimernet.tasks as tasks
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
......@@ -31,27 +31,25 @@ class MathDataset(Dataset):
def latex_rm_whitespace(s: str):
"""Remove unnecessary whitespace from LaTeX code.
"""
text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
letter = '[a-zA-Z]'
noletter = '[\W_^\d]'
names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
"""Remove unnecessary whitespace from LaTeX code."""
text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
letter = "[a-zA-Z]"
noletter = "[\W_^\d]"
names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
news = s
while True:
s = news
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
if news == s:
break
return s
class UnimernetModel(object):
def __init__(self, weight_dir, cfg_path, _device_='cpu'):
def __init__(self, weight_dir, cfg_path, _device_="cpu"):
args = argparse.Namespace(cfg_path=cfg_path, options=None)
cfg = Config(args)
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
......@@ -62,20 +60,28 @@ class UnimernetModel(object):
self.device = _device_
self.model.to(_device_)
self.model.eval()
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
self.mfr_transform = transforms.Compose([vis_processor, ])
vis_processor = load_processor(
"formula_image_eval",
cfg.config.datasets.formula_rec_eval.vis_processor.eval,
)
self.mfr_transform = transforms.Compose(
[
vis_processor,
]
)
def predict(self, mfd_res, image):
formula_list = []
mf_image_list = []
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
for xyxy, conf, cla in zip(
mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': 13 + int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 2),
'latex': '',
"category_id": 13 + int(cla.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 2),
"latex": "",
}
formula_list.append(new_item)
pil_img = Image.fromarray(image)
......@@ -88,11 +94,48 @@ class UnimernetModel(object):
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.model.generate({'image': mf_img})
mfr_res.extend(output['pred_str'])
output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"])
for res, latex in zip(formula_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
res["latex"] = latex_rm_whitespace(latex)
return formula_list
def batch_predict(
self, images_mfd_res: list, images: list, batch_size: int = 64
) -> list:
images_formula_list = []
mf_image_list = []
backfill_list = []
for image_index in range(len(images_mfd_res)):
mfd_res = images_mfd_res[image_index]
pil_img = Image.fromarray(images[image_index])
formula_list = []
for xyxy, conf, cla in zip(
mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
"category_id": 13 + int(cla.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 2),
"latex": "",
}
formula_list.append(new_item)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
mf_image_list.append(bbox_img)
images_formula_list.append(formula_list)
backfill_list += formula_list
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.model.generate({"image": mf_img})
mfr_res.extend(output["pred_str"])
for res, latex in zip(backfill_list, mfr_res):
res["latex"] = latex_rm_whitespace(latex)
return images_formula_list
import torch
from loguru import logger
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
......@@ -19,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
......@@ -29,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel()
table_model = RapidTableModel(ocr_engine)
else:
logger.error('table model type not allow')
exit(1)
......@@ -38,6 +40,8 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
def mfd_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
mfd_model = YOLOv8MFDModel(weight, device)
return mfd_model
......@@ -53,16 +57,26 @@ def layout_model_init(weight, config_file, device):
def doclayout_yolo_model_init(weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
model = DocLayoutYOLOModel(weight, device)
return model
def langdetect_model_init(langdetect_model_weight, device='cpu'):
if str(device).startswith("npu"):
device = torch.device(device)
model = YOLOv11LangDetModel(langdetect_model_weight, device)
return model
def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3,
lang=None,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
show_log=show_log,
......@@ -77,7 +91,6 @@ def ocr_model_init(show_log: bool = False,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
# use_angle_cls=True,
)
return model
......@@ -124,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
else:
logger.error('layout model name not allow')
exit(1)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get('mfd_weights'),
......@@ -146,8 +162,18 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_name'),
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device')
kwargs.get('device'),
kwargs.get('ocr_engine')
)
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
atom_model = langdetect_model_init(
kwargs.get('langdetect_model_weight'),
kwargs.get('device')
)
else:
logger.error('langdetect model name not allow')
exit(1)
else:
logger.error('model name not allow')
exit(1)
......
......@@ -45,7 +45,7 @@ def clean_vram(device, vram_threshold=8):
total_memory = get_vram(device)
if total_memory and total_memory <= vram_threshold:
gc_start = time.time()
clean_memory()
clean_memory(device)
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
......@@ -54,4 +54,10 @@ def get_vram(device):
if torch.cuda.is_available() and device != 'cpu':
total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
return total_memory
return None
\ No newline at end of file
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
return total_memory
else:
return None
\ No newline at end of file
......@@ -303,4 +303,54 @@ def calculate_is_angle(poly):
return False
else:
# logger.info((p3[1] - p1[1])/height)
return True
\ No newline at end of file
return True
class ONNXModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_onnx_model(self, **kwargs):
lang = kwargs.get('lang', None)
det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
use_dilation = kwargs.get('use_dilation', True)
det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
if key not in self._models:
self._models[key] = onnx_model_init(key)
return self._models[key]
def onnx_model_init(key):
import importlib.resources
resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
onnx_model = None
additional_ocr_params = {
"use_onnx": True,
"det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
"rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
"cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
"det_db_box_thresh": key[1],
"use_dilation": key[2],
"det_db_unclip_ratio": key[3],
}
# logger.info(f"additional_ocr_params: {additional_ocr_params}")
if key[0] is not None:
additional_ocr_params["lang"] = key[0]
from paddleocr import PaddleOCR
onnx_model = PaddleOCR(**additional_ocr_params)
if onnx_model is None:
logger.error('model init failed')
exit(1)
else:
return onnx_model
\ No newline at end of file
import copy
import platform
import time
import cv2
import numpy as np
import torch
from paddleocr import PaddleOCR
from ppocr.utils.logging import get_logger
......@@ -9,12 +11,25 @@ from ppocr.utils.utility import alpha_to_color, binarize_img
from tools.infer.predict_system import sorted_boxes
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
ONNXModelSingleton
logger = get_logger()
class ModifiedPaddleOCR(PaddleOCR):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lang = kwargs.get('lang', 'ch')
# 在cpu架构为arm且不支持cuda时调用onnx、
if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
self.use_onnx = True
onnx_model_manager = ONNXModelSingleton()
self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
else:
self.use_onnx = False
def ocr(self,
img,
det=True,
......@@ -79,7 +94,10 @@ class ModifiedPaddleOCR(PaddleOCR):
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
if dt_boxes is None:
ocr_res.append(None)
continue
......@@ -106,7 +124,10 @@ class ModifiedPaddleOCR(PaddleOCR):
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(img)
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img)
else:
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res)
if not rec:
return cls_res
......@@ -121,7 +142,10 @@ class ModifiedPaddleOCR(PaddleOCR):
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
if self.lang in ['ch'] and self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse
if dt_boxes is None:
......@@ -159,8 +183,10 @@ class ModifiedPaddleOCR(PaddleOCR):
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
if self.lang in ['ch'] and self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else:
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict['rec'] = elapse
logger.debug("rec_res num : {}, elapsed : {}".format(
len(rec_res), elapse))
......
import cv2
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR
class RapidTableModel(object):
def __init__(self):
def __init__(self, ocr_engine):
self.table_model = RapidTable()
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
# if ocr_engine is None:
# self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available():
# from rapidocr_paddle import RapidOCR
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
# else:
# from rapidocr_onnxruntime import RapidOCR
# self.ocr_engine = RapidOCR()
# else:
# self.ocr_model_name = "PaddleOCR"
# self.ocr_engine = ocr_engine
self.ocr_model_name = "RapidOCR"
if torch.cuda.is_available():
from rapidocr_paddle import RapidOCR
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
from rapidocr_onnxruntime import RapidOCR
self.ocr_engine = RapidOCR()
def predict(self, image):
ocr_result, _ = self.ocr_engine(np.asarray(image))
if ocr_result is None:
if self.ocr_model_name == "RapidOCR":
ocr_result, _ = self.ocr_engine(np.asarray(image))
elif self.ocr_model_name == "PaddleOCR":
bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_engine.ocr(bgr_image)[0]
if ocr_result:
ocr_result = [[item[0], item[1][0], item[1][1]] for item in ocr_result if
len(item) == 2 and isinstance(item[1], tuple)]
else:
ocr_result = None
else:
logger.error("OCR model not supported")
ocr_result = None
if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
else:
return None, None, None
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
\ No newline at end of file
from abc import ABC, abstractmethod
from typing import Callable
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.operators.pipes import PipeResult
class InferenceResultBase(ABC):
@abstractmethod
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
pass
@abstractmethod
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
pass
@abstractmethod
def dump_model(self, writer: DataWriter, file_path: str):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
pass
@abstractmethod
def get_infer_res(self):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
def pipe_txt_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_ocr_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
pass
......@@ -7,13 +7,11 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.filter import classify
from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.libs.version import __version__
from magic_pdf.model import InferenceResultBase
from magic_pdf.operators.pipes import PipeResult
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.operators import PipeResult
from magic_pdf.operators import InferenceResultBase
class InferenceResult(InferenceResultBase):
def __init__(self, inference_results: list, dataset: Dataset):
......@@ -71,40 +69,6 @@ class InferenceResult(InferenceResultBase):
"""
return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
def pipe_auto_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pdf_proc_method = classify(self._dataset.data_bits())
if pdf_proc_method == SupportedPdfParseMethod.TXT:
return self.pipe_txt_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
else:
return self.pipe_ocr_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
def pipe_txt_mode(
self,
imageWriter: DataWriter,
......
import copy
import json
import os
from typing import Callable
import copy
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
......@@ -23,12 +23,34 @@ class PipeResult:
self._pipe_res = pipe_res
self._dataset = dataset
def get_markdown(
self,
img_dir_or_bucket_prefix: str,
drop_mode=DropMode.NONE,
md_make_mode=MakeMode.MM_MD,
) -> str:
"""Get markdown content.
Args:
img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
Returns:
str: return markdown content
"""
pdf_info_list = self._pipe_res['pdf_info']
md_content = union_make(
pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
)
return md_content
def dump_md(
self,
writer: DataWriter,
file_path: str,
img_dir_or_bucket_prefix: str,
drop_mode=DropMode.WHOLE_PDF,
drop_mode=DropMode.NONE,
md_make_mode=MakeMode.MM_MD,
):
"""Dump The Markdown.
......@@ -37,36 +59,68 @@ class PipeResult:
writer (DataWriter): File writer handle
file_path (str): The file location of markdown
img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.WHOLE_PDF.
drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
"""
pdf_info_list = self._pipe_res['pdf_info']
md_content = union_make(
pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
md_content = self.get_markdown(
img_dir_or_bucket_prefix, drop_mode=drop_mode, md_make_mode=md_make_mode
)
writer.write_string(file_path, md_content)
def dump_content_list(
self, writer: DataWriter, file_path: str, image_dir_or_bucket_prefix: str
):
"""Dump Content List.
def get_content_list(
self,
image_dir_or_bucket_prefix: str,
drop_mode=DropMode.NONE,
) -> str:
"""Get Content List.
Args:
writer (DataWriter): File writer handle
file_path (str): The file location of content list
image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
Returns:
str: content list content
"""
pdf_info_list = self._pipe_res['pdf_info']
content_list = union_make(
pdf_info_list,
MakeMode.STANDARD_FORMAT,
DropMode.NONE,
drop_mode,
image_dir_or_bucket_prefix,
)
return content_list
def dump_content_list(
self,
writer: DataWriter,
file_path: str,
image_dir_or_bucket_prefix: str,
drop_mode=DropMode.NONE,
):
"""Dump Content List.
Args:
writer (DataWriter): File writer handle
file_path (str): The file location of content list
image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.NONE.
"""
content_list = self.get_content_list(
image_dir_or_bucket_prefix, drop_mode=drop_mode,
)
writer.write_string(
file_path, json.dumps(content_list, ensure_ascii=False, indent=4)
)
def get_middle_json(self) -> str:
"""Get middle json.
Returns:
str: The content of middle json
"""
return json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
def dump_middle_json(self, writer: DataWriter, file_path: str):
"""Dump the result of pipeline.
......@@ -74,9 +128,8 @@ class PipeResult:
writer (DataWriter): File writer handler
file_path (str): The file location of middle json
"""
writer.write_string(
file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
)
middle_json = self.get_middle_json()
writer.write_string(file_path, middle_json)
def draw_layout(self, file_path: str) -> None:
"""Draw the layout.
......@@ -123,7 +176,7 @@ class PipeResult:
Returns:
str: compress the pipeline result and return
"""
return JsonCompressor.compress_json(self.pdf_mid_data)
return JsonCompressor.compress_json(self._pipe_res)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
......
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_ocr(dataset: Dataset,
model_list,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
return pdf_parse_union(model_list,
dataset,
imageWriter,
SupportedPdfParseMethod.OCR,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_txt(
dataset: Dataset,
model_list,
imageWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
return pdf_parse_union(model_list,
dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
import copy
import os
import re
import statistics
import time
from typing import List
......@@ -13,11 +14,12 @@ from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
try:
import torchtext
......@@ -28,15 +30,15 @@ except ImportError:
pass
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.post_proc.para_split_v3 import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, remove_overlaps_min_spans
from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, \
remove_overlaps_min_spans, check_chars_is_overlap_in_span
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
def __replace_STX_ETX(text_str: str):
......@@ -63,11 +65,22 @@ def __replace_0xfffd(text_str: str):
return s
return text_str
# 连写字符拆分
def __replace_ligatures(text: str):
ligatures = {
'fi': 'fi', 'fl': 'fl', 'ff': 'ff', 'ffi': 'ffi', 'ffl': 'ffl', 'ſt': 'ft', 'st': 'st'
}
return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
def chars_to_content(span):
# 检查span中的char是否为空
if len(span['chars']) == 0:
pass
# span['content'] = ''
elif check_chars_is_overlap_in_span(span['chars']):
pass
else:
# 先给chars按char['bbox']的中心点的x坐标排序
span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
......@@ -78,11 +91,16 @@ def chars_to_content(span):
content = ''
for char in span['chars']:
# 如果下一个char的x0和上一个char的x1距离超过一个字符宽度,则需要在中间插入一个空格
if char['bbox'][0] - span['chars'][span['chars'].index(char) - 1]['bbox'][2] > char_avg_width:
content += ' '
content += char['c']
# 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
char1 = char
char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
if char2 and char2['bbox'][0] - char1['bbox'][2] > char_avg_width * 0.25 and char['c'] != ' ' and char2['c'] != ' ':
content += f"{char['c']} "
else:
content += char['c']
content = __replace_ligatures(content)
span['content'] = __replace_0xfffd(content)
del span['chars']
......@@ -98,6 +116,10 @@ def fill_char_in_spans(spans, all_chars):
spans = sorted(spans, key=lambda x: x['bbox'][1])
for char in all_chars:
# 跳过非法bbox的char
x1, y1, x2, y2 = char['bbox']
if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
continue
for span in spans:
if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
span['chars'].append(char)
......@@ -152,14 +174,16 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
# cid用0xfffd表示,连字符拆开
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
# cid用0xfffd表示,连字符不拆开
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
all_pymu_chars = []
for block in text_blocks_raw:
for line in block['lines']:
cosine, sine = line['dir']
if abs (cosine) < 0.9 or abs(sine) > 0.1:
if abs(cosine) < 0.9 or abs(sine) > 0.1:
continue
for span in line['spans']:
all_pymu_chars.extend(span['chars'])
......@@ -255,19 +279,23 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
return spans
def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def model_init(model_name: str):
from transformers import LayoutLMv3ForTokenClassification
device = get_device()
if torch.cuda.is_available():
device = torch.device('cuda')
if torch.cuda.is_bf16_supported():
supports_bfloat16 = True
else:
supports_bfloat16 = False
elif str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
device = torch.device('npu')
supports_bfloat16 = False
else:
device = torch.device('cpu')
supports_bfloat16 = False
else:
device = torch.device('cpu')
supports_bfloat16 = False
......@@ -345,6 +373,8 @@ def cal_block_index(fix_blocks, sorted_bboxes):
# 使用xycut排序
block_bboxes = []
for block in fix_blocks:
# 如果block['bbox']任意值小于0,将其置为0
block['bbox'] = [max(0, x) for x in block['bbox']]
block_bboxes.append(block['bbox'])
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
......@@ -738,6 +768,11 @@ def parse_page_core(
"""重排block"""
sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
"""block内重排(img和table的block内多个caption或footnote的排序)"""
for block in sorted_blocks:
if block['type'] in [BlockType.Image, BlockType.Table]:
block['blocks'] = sorted(block['blocks'], key=lambda b: b['index'])
"""获取QA需要外置的list"""
images, tables, interline_equations = get_qa_need_list_v2(sorted_blocks)
......@@ -819,13 +854,29 @@ def pdf_parse_union(
"""分段"""
para_split(pdf_info_dict)
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""公式优化"""
formula_aided_config = llm_aided_config.get('formula_aided', None)
if formula_aided_config is not None:
llm_aided_formula(pdf_info_dict, formula_aided_config)
"""文本优化"""
text_aided_config = llm_aided_config.get('text_aided', None)
if text_aided_config is not None:
llm_aided_text(pdf_info_dict, text_aided_config)
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
llm_aided_title(pdf_info_dict, title_aided_config)
"""dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = {
'pdf_info': pdf_info_list,
}
clean_memory()
clean_memory(get_device())
return new_pdf_info_dict
......
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
def __init__(
self,
dataset: Dataset,
model_list: list,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
super().__init__(
dataset,
model_list,
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.infer_res = doc_analyze(
self.dataset,
ocr=True,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(
self.dataset,
self.infer_res,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('ocr_pipe mk content list finished')
return result
def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'ocr_pipe mk {md_make_mode} finished')
return result
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe):
def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(dataset, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable)
def pipe_classify(self):
pass
def pipe_analyze(self):
self.model_list = doc_analyze(self.dataset, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.dataset, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('txt_pipe mk content list finished')
return result
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'txt_pipe mk {md_make_mode} finished')
return result
import json
from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.commons import join_path
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
class UNIPipe(AbsPipe):
def __init__(
self,
dataset: Dataset,
jso_useful_key: dict,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
self.pdf_type = jso_useful_key['_pdf_type']
super().__init__(
dataset,
jso_useful_key['model_list'],
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
if len(self.model_list) == 0:
self.input_model_is_empty = True
else:
self.input_model_is_empty = False
def pipe_classify(self):
self.pdf_type = AbsPipe.classify(self.pdf_bytes)
def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(
self.dataset,
ocr=False,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(
self.dataset,
ocr=True,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self):
if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(
self.dataset,
self.model_list,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(
self.dataset,
self.model_list,
self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
)
def pipe_mk_uni_format(
self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON
):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('uni_pipe mk content list finished')
return result
def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'uni_pipe mk {md_make_mode} finished')
return result
if __name__ == '__main__':
# 测试
from magic_pdf.data.data_reader_writer import DataReader
drw = DataReader(r'D:/project/20231108code-clean')
pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r'linshixuqiu\19983-00.json'
pdf_bytes = drw.read(pdf_file_path)
model_json_txt = drw.read(model_file_path).decode()
model_list = json.loads(model_json_txt)
write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = 'imgs'
img_writer = DataWriter(join_path(write_path, img_bucket_path))
# pdf_type = UNIPipe.classify(pdf_bytes)
# jso_useful_key = {
# "_pdf_type": pdf_type,
# "model_list": model_list
# }
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
pipe.pipe_classify()
pipe.pipe_parse()
md_content = pipe.pipe_mk_markdown(img_bucket_path)
content_list = pipe.pipe_mk_uni_format(img_bucket_path)
md_writer = DataWriter(write_path)
md_writer.write_string('19983-00.md', md_content)
md_writer.write_string(
'19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
)
md_writer.write_string('19983-00.txt', str(content_list))
# Copyright (c) Opendatalab. All rights reserved.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment