Commit a3a720ea authored by icecraft, committed by xu rui
Browse files

refactor: isolate inference and pipeline

parent fdf47155
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterator from typing import Callable, Iterator
import fitz import fitz
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image from magic_pdf.data.utils import fitz_doc_to_image
from magic_pdf.filter import classify
class PageableData(ABC): class PageableData(ABC):
...@@ -28,6 +30,14 @@ class PageableData(ABC): ...@@ -28,6 +30,14 @@ class PageableData(ABC):
""" """
pass pass
@abstractmethod
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
    """Draw a rectangle on this page.

    Abstract hook: the base declaration does nothing and returns None.
    Concrete page wrappers decide how the arguments are interpreted.

    Args:
        rect_coords: the rectangle to draw.
        color: passed through as the ``color`` drawing option.
        fill: passed through as the ``fill`` drawing option.
        fill_opacity: passed through as the ``fill_opacity`` drawing option.
        width: passed through as the ``width`` drawing option.
        overlay: passed through as the ``overlay`` drawing option.
    """
@abstractmethod
def insert_text(self, coord, content, fontsize, color):
    """Write text onto this page.

    Abstract hook: the base declaration does nothing and returns None.

    Args:
        coord: where the text is placed.
        content: the text to insert.
        fontsize: passed through as the ``fontsize`` option.
        color: passed through as the ``color`` option.
    """
class Dataset(ABC): class Dataset(ABC):
@abstractmethod @abstractmethod
...@@ -66,6 +76,18 @@ class Dataset(ABC): ...@@ -66,6 +76,18 @@ class Dataset(ABC):
""" """
pass pass
@abstractmethod
def dump_to_file(self, file_path: str):
    """Write the dataset's document to ``file_path``.

    Abstract hook: the base declaration does nothing and returns None.

    Args:
        file_path: destination path of the dumped file.
    """
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
    """Run ``proc`` against this dataset.

    Abstract hook: the base declaration does nothing and returns None.

    Args:
        proc: the callable to invoke.
        *args: extra positional arguments for ``proc``.
        **kwargs: extra keyword arguments for ``proc``.
    """
@abstractmethod
def classify(self) -> SupportedPdfParseMethod:
    """Report which parse method suits this dataset.

    Abstract hook: the base declaration does nothing; concrete datasets
    return a ``SupportedPdfParseMethod`` member.
    """
class PymuDocDataset(Dataset): class PymuDocDataset(Dataset):
def __init__(self, bits: bytes): def __init__(self, bits: bytes):
...@@ -74,7 +96,8 @@ class PymuDocDataset(Dataset): ...@@ -74,7 +96,8 @@ class PymuDocDataset(Dataset):
Args: Args:
bits (bytes): the bytes of the pdf bits (bytes): the bytes of the pdf
""" """
self._records = [Doc(v) for v in fitz.open('pdf', bits)] self._raw_fitz = fitz.open('pdf', bits)
self._records = [Doc(v) for v in self._raw_fitz]
self._data_bits = bits self._data_bits = bits
self._raw_data = bits self._raw_data = bits
...@@ -109,6 +132,19 @@ class PymuDocDataset(Dataset): ...@@ -109,6 +132,19 @@ class PymuDocDataset(Dataset):
""" """
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str):
    """Save the underlying fitz document to ``file_path``.

    The parent directory is created first, unless the path has no directory
    part or that part is ``'.'`` or ``'..'``.

    Args:
        file_path: destination path handed to the fitz document's ``save``.
    """
    parent = os.path.dirname(file_path)
    needs_mkdir = parent != '' and parent != '.' and parent != '..'
    if needs_mkdir:
        os.makedirs(parent, exist_ok=True)
    self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
    """Call ``proc`` with this dataset as its first positional argument.

    Args:
        proc: the callable to invoke.
        *args: positional arguments passed to ``proc`` after the dataset.
        **kwargs: keyword arguments passed to ``proc`` unchanged.

    Returns:
        Whatever ``proc`` returns.
    """
    # Unpack directly instead of building a throwaway list and tuple; this is
    # the same form ImageDataset.apply uses.
    return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
    """Classify this PDF by running the filter module's ``classify`` on its bytes.

    Returns:
        The parse method reported for ``self._data_bits``.
    """
    parse_method = classify(self._data_bits)
    return parse_method
class ImageDataset(Dataset): class ImageDataset(Dataset):
def __init__(self, bits: bytes): def __init__(self, bits: bytes):
...@@ -118,7 +154,8 @@ class ImageDataset(Dataset): ...@@ -118,7 +154,8 @@ class ImageDataset(Dataset):
bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc. bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
""" """
pdf_bytes = fitz.open(stream=bits).convert_to_pdf() pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)] self._raw_fitz = fitz.open('pdf', pdf_bytes)
self._records = [Doc(v) for v in self._raw_fitz]
self._raw_data = bits self._raw_data = bits
self._data_bits = pdf_bytes self._data_bits = pdf_bytes
...@@ -153,9 +190,22 @@ class ImageDataset(Dataset): ...@@ -153,9 +190,22 @@ class ImageDataset(Dataset):
""" """
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str):
    """Save the converted fitz document to ``file_path``.

    A missing parent directory is created beforehand; an empty directory
    part, ``'.'`` and ``'..'`` are left alone.

    Args:
        file_path: destination path handed to the fitz document's ``save``.
    """
    folder = os.path.dirname(file_path)
    if folder and folder not in {'.', '..'}:
        os.makedirs(folder, exist_ok=True)
    self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
    """Invoke ``proc`` with this dataset prepended to the given arguments.

    Args:
        proc: the callable to invoke.
        *args: positional arguments passed to ``proc`` after the dataset.
        **kwargs: keyword arguments passed to ``proc`` unchanged.

    Returns:
        Whatever ``proc`` returns.
    """
    outcome = proc(self, *args, **kwargs)
    return outcome
def classify(self) -> SupportedPdfParseMethod:
    """Return the parse method for an image-backed dataset, which is always OCR."""
    parse_method = SupportedPdfParseMethod.OCR
    return parse_method
class Doc(PageableData): class Doc(PageableData):
"""Initialized with pymudoc object.""" """Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page): def __init__(self, doc: fitz.Page):
self._doc = doc self._doc = doc
...@@ -192,3 +242,16 @@ class Doc(PageableData): ...@@ -192,3 +242,16 @@ class Doc(PageableData):
def __getattr__(self, name): def __getattr__(self, name):
if hasattr(self._doc, name): if hasattr(self._doc, name):
return getattr(self._doc, name) return getattr(self._doc, name)
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
    """Draw a rectangle by delegating to the wrapped page's ``draw_rect``.

    ``rect_coords`` is passed positionally; every other argument is passed
    as the keyword of the same name. The wrapped call's return value is
    discarded.
    """
    drawing_options = {
        'color': color,
        'fill': fill,
        'fill_opacity': fill_opacity,
        'width': width,
        'overlay': overlay,
    }
    self._doc.draw_rect(rect_coords, **drawing_options)
def insert_text(self, coord, content, fontsize, color):
    """Insert text by delegating to the wrapped page's ``insert_text``.

    ``coord`` and ``content`` are passed positionally, ``fontsize`` and
    ``color`` as keywords. The wrapped call's return value is discarded.
    """
    page = self._doc
    page.insert_text(coord, content, color=color, fontsize=fontsize)
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.filter.pdf_classify_by_type import classify as do_classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
    """Decide from the PDF's metadata whether it is a text PDF or an OCR PDF.

    Args:
        pdf_bytes: raw bytes of the PDF, handed to ``pdf_meta_scan``.

    Returns:
        ``SupportedPdfParseMethod.TXT`` when the classifier reports a text
        PDF, otherwise ``SupportedPdfParseMethod.OCR``.

    Raises:
        Exception: when the meta scan sets ``_need_drop``, or when the
            document is encrypted or needs a password.
    """
    pdf_meta = pdf_meta_scan(pdf_bytes)

    # The scan returned the "needs to be dropped" flag, so give up with its reason.
    if pdf_meta.get('_need_drop', False):
        raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")

    # Encrypted or password-protected documents are not processed.
    is_encrypted = pdf_meta['is_encrypted']
    is_needs_password = pdf_meta['is_needs_password']
    if is_encrypted or is_needs_password:
        raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')

    # Metadata fields, in the positional order the classifier expects.
    meta_fields = (
        'total_page',
        'page_width_pts',
        'page_height_pts',
        'image_info_per_page',
        'text_len_per_page',
        'imgs_per_page',
        'text_layout_per_page',
        'invalid_chars',
    )
    is_text_pdf, _ = do_classify(*[pdf_meta[field] for field in meta_fields])

    return SupportedPdfParseMethod.TXT if is_text_pdf else SupportedPdfParseMethod.OCR
import fitz import fitz
from magic_pdf.config.constants import CROSS_PAGE from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import BlockType, CategoryId, ContentType from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
from magic_pdf.data.dataset import PymuDocDataset ContentType)
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.commons import fitz # PyMuPDF
from magic_pdf.model.magic_model import MagicModel from magic_pdf.model.magic_model import MagicModel
...@@ -194,7 +196,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -194,7 +196,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
) )
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf') pdf_docs.save(f'{out_path}/{filename}')
def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
...@@ -282,18 +284,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -282,18 +284,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False) draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_spans.pdf') pdf_docs.save(f'{out_path}/{filename}')
def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
dropped_bbox_list = [] dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], [] tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], [] imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
titles_list = [] titles_list = []
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
pdf_docs = fitz.open('pdf', pdf_bytes) magic_model = MagicModel(model_list, dataset)
magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(len(model_list)): for i in range(len(model_list)):
page_dropped_list = [] page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], [] tables_body, tables_caption, tables_footnote = [], [], []
...@@ -337,7 +338,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -337,7 +338,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list.append(page_dropped_list) dropped_bbox_list.append(page_dropped_list)
imgs_footnote_list.append(imgs_footnote) imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs): for i in range(len(dataset)):
page = dataset.get_page(i)
draw_bbox_with_number( draw_bbox_with_number(
i, dropped_bbox_list, page, [158, 158, 158], True i, dropped_bbox_list, page, [158, 158, 158], True
) # color ! ) # color !
...@@ -352,7 +354,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -352,7 +354,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True) draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf') dataset.dump_to_file(f'{out_path}/{filename}')
def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
......
import time import time
import fitz import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_formula_config get_layout_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config from magic_pdf.model.types import InferenceResult
def dict_compare(d1, d2): def dict_compare(d1, d2):
...@@ -25,19 +30,25 @@ def remove_duplicates_dicts(lst): ...@@ -25,19 +30,25 @@ def remove_duplicates_dicts(lst):
return unique_dicts return unique_dicts
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list: def load_images_from_pdf(
pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
) -> list:
try: try:
from PIL import Image from PIL import Image
except ImportError: except ImportError:
logger.error("Pillow not installed, please install by pip.") logger.error('Pillow not installed, please install by pip.')
exit(1) exit(1)
images = [] images = []
with fitz.open("pdf", pdf_bytes) as doc: with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1 end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1: if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length") logger.warning('end_page_id is out of range, use images length')
end_page_id = pdf_page_num - 1 end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count): for index in range(0, doc.page_count):
...@@ -50,11 +61,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id ...@@ -50,11 +61,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if pm.width > 4500 or pm.height > 4500: if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples) img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img) img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height} img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else: else:
img_dict = {"img": [], "width": 0, "height": 0} img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict) images.append(img_dict)
return images return images
...@@ -69,117 +80,150 @@ class ModelSingleton: ...@@ -69,117 +80,150 @@ class ModelSingleton:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None): def get_model(
self,
ocr: bool,
show_log: bool,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable) key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models: if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model, self._models[key] = custom_model_init(
formula_enable=formula_enable, table_enable=table_enable) ocr=ocr,
show_log=show_log,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
return self._models[key] return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None, def custom_model_init(
layout_model=None, formula_enable=None, table_enable=None): ocr: bool = False,
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
model = None model = None
if model_config.__model_mode__ == "lite": if model_config.__model_mode__ == 'lite':
logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is " logger.warning(
"not guaranteed to be reliable.") 'The Lite mode is provided for developers to conduct testing only, and the output quality is '
'not guaranteed to be reliable.'
)
model = MODEL.Paddle model = MODEL.Paddle
elif model_config.__model_mode__ == "full": elif model_config.__model_mode__ == 'full':
model = MODEL.PEK model = MODEL.PEK
if model_config.__use_inside_model__: if model_config.__use_inside_model__:
model_init_start = time.time() model_init_start = time.time()
if model == MODEL.Paddle: if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang) custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK: elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir() local_models_dir = get_local_models_dir()
device = get_device() device = get_device()
layout_config = get_layout_config() layout_config = get_layout_config()
if layout_model is not None: if layout_model is not None:
layout_config["model"] = layout_model layout_config['model'] = layout_model
formula_config = get_formula_config() formula_config = get_formula_config()
if formula_enable is not None: if formula_enable is not None:
formula_config["enable"] = formula_enable formula_config['enable'] = formula_enable
table_config = get_table_recog_config() table_config = get_table_recog_config()
if table_enable is not None: if table_enable is not None:
table_config["enable"] = table_enable table_config['enable'] = table_enable
model_input = { model_input = {
"ocr": ocr, 'ocr': ocr,
"show_log": show_log, 'show_log': show_log,
"models_dir": local_models_dir, 'models_dir': local_models_dir,
"device": device, 'device': device,
"table_config": table_config, 'table_config': table_config,
"layout_config": layout_config, 'layout_config': layout_config,
"formula_config": formula_config, 'formula_config': formula_config,
"lang": lang, 'lang': lang,
} }
custom_model = CustomPEKModel(**model_input) custom_model = CustomPEKModel(**model_input)
else: else:
logger.error("Not allow model_name!") logger.error('Not allow model_name!')
exit(1) exit(1)
model_init_cost = time.time() - model_init_start model_init_cost = time.time() - model_init_start
logger.info(f"model init cost: {model_init_cost}") logger.info(f'model init cost: {model_init_cost}')
else: else:
logger.error("use_inside_model is False, not allow to use inside model") logger.error('use_inside_model is False, not allow to use inside model')
exit(1) exit(1)
return custom_model return custom_model
def doc_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
) -> InferenceResult:
    """Run the layout model over the pages of ``dataset``.

    Args:
        dataset: the document to analyze, one page per ``get_page`` index.
        ocr: forwarded to the model factory.
        show_log: forwarded to the model factory.
        start_page_id: first page index (inclusive) to run the model on.
        end_page_id: last page index (inclusive) to run the model on.
            ``None`` or a negative value means "through the last page"; a
            value past the last page is clamped with a warning.
        lang: forwarded to the model factory; ``''`` is treated as ``None``.
        layout_model: forwarded to the model factory.
        formula_enable: forwarded to the model factory.
        table_enable: forwarded to the model factory.

    Returns:
        An ``InferenceResult`` wrapping one dict per page, each with
        ``layout_dets`` and ``page_info``. Pages outside the requested range
        get an empty ``layout_dets`` list.
    """
    if lang == '':
        lang = None

    model_manager = ModelSingleton()
    custom_model = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )

    model_json = []
    doc_analyze_start = time.time()

    # Normalize end_page_id to a valid inclusive index, the same way
    # load_images_from_pdf does, so the range test and the speed figure below
    # both work from the real last page.
    page_count = len(dataset)
    last_page_id = page_count - 1
    if end_page_id is None or end_page_id < 0:
        end_page_id = last_page_id
    elif end_page_id > last_page_id:
        logger.warning('end_page_id is out of range, use images length')
        end_page_id = last_page_id

    for index in range(page_count):
        # Every page is rendered, including those outside the range, because
        # page_info records each page's width and height.
        page_data = dataset.get_page(index)
        img_dict = page_data.get_image()
        img = img_dict['img']
        page_width = img_dict['width']
        page_height = img_dict['height']
        if start_page_id <= index <= end_page_id:
            page_start = time.time()
            result = custom_model(img)
            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
        else:
            result = []

        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
    pages_analyzed = max(end_page_id + 1 - start_page_id, 0)
    # A very short run rounds to 0.0 seconds; report 0.0 instead of dividing by zero.
    if doc_analyze_time > 0:
        doc_analyze_speed = round(pages_analyzed / doc_analyze_time, 2)
    else:
        doc_analyze_speed = 0.0
    logger.info(
        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
        f' speed: {doc_analyze_speed} pages/second'
    )

    return InferenceResult(model_json, dataset)
import copy
import json
import os
from typing import Callable
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.filter import classify
from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.types import PipeResult
class InferenceResult:
    """Model inference output for one dataset, plus entry points into the pipeline."""

    def __init__(self, inference_results: list, dataset: Dataset):
        """Store the per-page inference results and the dataset they describe.

        Args:
            inference_results: the model output, one entry per page.
            dataset: the dataset the results were produced from.
        """
        self._infer_res = inference_results
        self._dataset = dataset

    def draw_model(self, file_path: str) -> None:
        """Draw the model's boxes on the dataset and save the result.

        Args:
            file_path: destination of the annotated file. Its directory is
                created when missing.
        """
        # dirname() is '' for a bare file name. makedirs('') would raise and
        # the drawing helper would build the path '/<name>', so fall back to
        # the current directory.
        dir_name = os.path.dirname(file_path) or '.'
        base_name = os.path.basename(file_path)
        os.makedirs(dir_name, exist_ok=True)
        # Hand over a deep copy so the drawing helper cannot alter the stored results.
        draw_model_bbox(
            copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
        )

    def dump_model(self, writer: DataWriter, file_path: str):
        """Write the inference results as indented JSON through ``writer``.

        Args:
            writer: receives the JSON text via ``write_string``.
            file_path: path passed to the writer.
        """
        writer.write_string(
            file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
        )

    def get_infer_res(self):
        """Return the stored inference results (the same object, not a copy)."""
        return self._infer_res

    def apply(self, proc: Callable, *args, **kwargs):
        """Call ``proc`` with a deep copy of the results as first argument.

        Args:
            proc: the callable to invoke.
            *args: positional arguments passed after the copied results.
            **kwargs: keyword arguments passed unchanged.

        Returns:
            Whatever ``proc`` returns.
        """
        return proc(copy.deepcopy(self._infer_res), *args, **kwargs)

    def _run_pipeline(
        self,
        parse_method,
        imageWriter: DataWriter,
        start_page_id,
        end_page_id,
        debug_mode,
        lang,
    ) -> PipeResult:
        """Run ``pdf_parse_union`` with ``parse_method`` and wrap its output."""

        def proc(*args, **kwargs) -> PipeResult:
            res = pdf_parse_union(*args, **kwargs)
            return PipeResult(res, self._dataset)

        # Forward the caller's page range and options instead of fixed defaults.
        return self.apply(
            proc,
            self._dataset,
            imageWriter,
            parse_method,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            debug_mode=debug_mode,
            lang=lang,
        )

    def pipe_auto_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Classify the dataset, then parse it in TXT or OCR mode accordingly.

        Args:
            imageWriter: writer handed to the parser.
            start_page_id: forwarded to the parser.
            end_page_id: forwarded to the parser.
            debug_mode: forwarded to the parser.
            lang: forwarded to the parser.

        Returns:
            The parse output wrapped in a ``PipeResult``.
        """
        pdf_proc_method = classify(self._dataset.data_bits())
        if pdf_proc_method == SupportedPdfParseMethod.TXT:
            parse_method = SupportedPdfParseMethod.TXT
        else:
            parse_method = SupportedPdfParseMethod.OCR
        return self._run_pipeline(
            parse_method, imageWriter, start_page_id, end_page_id, debug_mode, lang
        )

    def pipe_txt_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Parse the dataset in TXT mode.

        Args:
            imageWriter: writer handed to the parser.
            start_page_id: forwarded to the parser.
            end_page_id: forwarded to the parser.
            debug_mode: forwarded to the parser.
            lang: forwarded to the parser.

        Returns:
            The parse output wrapped in a ``PipeResult``.
        """
        return self._run_pipeline(
            SupportedPdfParseMethod.TXT,
            imageWriter,
            start_page_id,
            end_page_id,
            debug_mode,
            lang,
        )

    def pipe_ocr_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Parse the dataset in OCR mode.

        Args:
            imageWriter: writer handed to the parser.
            start_page_id: forwarded to the parser.
            end_page_id: forwarded to the parser.
            debug_mode: forwarded to the parser.
            lang: forwarded to the parser.

        Returns:
            The parse output wrapped in a ``PipeResult``.
        """
        # OCR mode must request the OCR parse method (it previously passed TXT).
        return self._run_pipeline(
            SupportedPdfParseMethod.OCR,
            imageWriter,
            start_page_id,
            end_page_id,
            debug_mode,
            lang,
        )
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_ocr(pdf_bytes, def parse_pdf_by_ocr(dataset: Dataset,
model_list, model_list,
imageWriter, imageWriter,
start_page_id=0, start_page_id=0,
...@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes, ...@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes,
debug_mode=False, debug_mode=False,
lang=None, lang=None,
): ):
dataset = PymuDocDataset(pdf_bytes) return pdf_parse_union(model_list,
return pdf_parse_union(dataset, dataset,
model_list,
imageWriter, imageWriter,
SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.OCR,
start_page_id=start_page_id, start_page_id=start_page_id,
......
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_txt( def parse_pdf_by_txt(
pdf_bytes, dataset: Dataset,
model_list, model_list,
imageWriter, imageWriter,
start_page_id=0, start_page_id=0,
...@@ -12,9 +12,8 @@ def parse_pdf_by_txt( ...@@ -12,9 +12,8 @@ def parse_pdf_by_txt(
debug_mode=False, debug_mode=False,
lang=None, lang=None,
): ):
dataset = PymuDocDataset(pdf_bytes) return pdf_parse_union(model_list,
return pdf_parse_union(dataset, dataset,
model_list,
imageWriter, imageWriter,
SupportedPdfParseMethod.TXT, SupportedPdfParseMethod.TXT,
start_page_id=start_page_id, start_page_id=start_page_id,
......
...@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod ...@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from magic_pdf.config.drop_reason import DropReason from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.dict2md.ocr_mkcontent import union_make from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.filter.pdf_classify_by_type import classify from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
...@@ -14,9 +15,9 @@ class AbsPipe(ABC): ...@@ -14,9 +15,9 @@ class AbsPipe(ABC):
PIP_OCR = 'ocr' PIP_OCR = 'ocr'
PIP_TXT = 'txt' PIP_TXT = 'txt'
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False, def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None): start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.pdf_bytes = pdf_bytes self.dataset = Dataset
self.model_list = model_list self.model_list = model_list
self.image_writer = image_writer self.image_writer = image_writer
self.pdf_mid_data = None # 未压缩 self.pdf_mid_data = None # 未压缩
......
...@@ -2,40 +2,79 @@ from loguru import logger ...@@ -2,40 +2,79 @@ from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
    """Pipeline that analyzes with OCR enabled and parses via ``parse_ocr_pdf``."""

    def __init__(
        self,
        dataset: Dataset,
        model_list: list,
        image_writer: DataWriter,
        is_debug: bool = False,
        start_page_id=0,
        end_page_id=None,
        lang=None,
        layout_model=None,
        formula_enable=None,
        table_enable=None,
    ):
        """Pass every argument through to ``AbsPipe.__init__`` unchanged."""
        super().__init__(
            dataset,
            model_list,
            image_writer,
            is_debug,
            start_page_id,
            end_page_id,
            lang,
            layout_model,
            formula_enable,
            table_enable,
        )

    def pipe_classify(self):
        """Do nothing; this pipe has no classification step."""
        pass

    def pipe_analyze(self):
        """Run ``doc_analyze`` with OCR enabled and keep the result in ``self.infer_res``."""
        self.infer_res = doc_analyze(
            self.dataset,
            ocr=True,
            start_page_id=self.start_page_id,
            end_page_id=self.end_page_id,
            lang=self.lang,
            layout_model=self.layout_model,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

    def pipe_parse(self):
        """Parse the dataset with ``parse_ocr_pdf`` and store it in ``self.pdf_mid_data``."""
        # Within this class only pipe_analyze assigns infer_res. When the
        # caller supplied model_list and skipped analysis, use that instead
        # of reading an attribute that may not exist.
        # NOTE(review): confirm parse_ocr_pdf accepts both forms.
        infer_res = getattr(self, 'infer_res', None)
        if infer_res is None:
            infer_res = self.model_list
        self.pdf_mid_data = parse_ocr_pdf(
            self.dataset,
            infer_res,
            self.image_writer,
            is_debug=self.is_debug,
            start_page_id=self.start_page_id,
            end_page_id=self.end_page_id,
            lang=self.lang,
            layout_model=self.layout_model,
            formula_enable=self.formula_enable,
            table_enable=self.table_enable,
        )

    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
        """Delegate to the base implementation and log completion."""
        result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
        logger.info('ocr_pipe mk content list finished')
        return result

    def pipe_mk_markdown(
        self,
        img_parent_path: str,
        drop_mode=DropMode.WHOLE_PDF,
        md_make_mode=MakeMode.MM_MD,
    ):
        """Delegate to the base implementation and log which make mode finished."""
        result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
        logger.info(f'ocr_pipe mk {md_make_mode} finished')
        return result
...@@ -2,6 +2,7 @@ from loguru import logger ...@@ -2,6 +2,7 @@ from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_txt_pdf from magic_pdf.user_api import parse_txt_pdf
...@@ -9,23 +10,23 @@ from magic_pdf.user_api import parse_txt_pdf ...@@ -9,23 +10,23 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe): class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False, def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None, start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None): layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang, super().__init__(dataset, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable) layout_model, formula_enable, table_enable)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.dataset, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model, lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable) formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_txt_pdf(self.dataset, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model, lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable) formula_enable=self.formula_enable, table_enable=self.table_enable)
......
...@@ -4,6 +4,7 @@ from loguru import logger ...@@ -4,6 +4,7 @@ from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import join_path
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe from magic_pdf.pipe.AbsPipe import AbsPipe
...@@ -12,12 +13,32 @@ from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf ...@@ -12,12 +13,32 @@ from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
class UNIPipe(AbsPipe): class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: DataWriter, is_debug: bool = False, def __init__(
start_page_id=0, end_page_id=None, lang=None, self,
layout_model=None, formula_enable=None, table_enable=None): dataset: Dataset,
jso_useful_key: dict,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
self.pdf_type = jso_useful_key['_pdf_type'] self.pdf_type = jso_useful_key['_pdf_type']
super().__init__(pdf_bytes, jso_useful_key['model_list'], image_writer, is_debug, start_page_id, end_page_id, super().__init__(
lang, layout_model, formula_enable, table_enable) dataset,
jso_useful_key['model_list'],
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
if len(self.model_list) == 0: if len(self.model_list) == 0:
self.input_model_is_empty = True self.input_model_is_empty = True
else: else:
...@@ -28,35 +49,66 @@ class UNIPipe(AbsPipe): ...@@ -28,35 +49,66 @@ class UNIPipe(AbsPipe):
def pipe_analyze(self): def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.dataset,
lang=self.lang, layout_model=self.layout_model, ocr=False,
formula_enable=self.formula_enable, table_enable=self.table_enable) start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.dataset,
lang=self.lang, layout_model=self.layout_model, ocr=True,
formula_enable=self.formula_enable, table_enable=self.table_enable) start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self): def pipe_parse(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_union_pdf(
is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty, self.dataset,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.model_list,
lang=self.lang, layout_model=self.layout_model, self.image_writer,
formula_enable=self.formula_enable, table_enable=self.table_enable) is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_ocr_pdf(
self.dataset,
self.model_list,
self.image_writer,
is_debug=self.is_debug, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id,
lang=self.lang) end_page_id=self.end_page_id,
lang=self.lang,
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON): )
def pipe_mk_uni_format(
self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON
):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('uni_pipe mk content list finished') logger.info('uni_pipe mk content list finished')
return result return result
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD): def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode) result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'uni_pipe mk {md_make_mode} finished') logger.info(f'uni_pipe mk {md_make_mode} finished')
return result return result
...@@ -65,6 +117,7 @@ class UNIPipe(AbsPipe): ...@@ -65,6 +117,7 @@ class UNIPipe(AbsPipe):
if __name__ == '__main__': if __name__ == '__main__':
# 测试 # 测试
from magic_pdf.data.data_reader_writer import DataReader from magic_pdf.data.data_reader_writer import DataReader
drw = DataReader(r'D:/project/20231108code-clean') drw = DataReader(r'D:/project/20231108code-clean')
pdf_file_path = r'linshixuqiu\19983-00.pdf' pdf_file_path = r'linshixuqiu\19983-00.pdf'
...@@ -82,10 +135,7 @@ if __name__ == '__main__': ...@@ -82,10 +135,7 @@ if __name__ == '__main__':
# "model_list": model_list # "model_list": model_list
# } # }
jso_useful_key = { jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
'_pdf_type': '',
'model_list': model_list
}
pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer) pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
pipe.pipe_classify() pipe.pipe_classify()
pipe.pipe_parse() pipe.pipe_parse()
...@@ -94,5 +144,7 @@ if __name__ == '__main__': ...@@ -94,5 +144,7 @@ if __name__ == '__main__':
md_writer = DataWriter(write_path) md_writer = DataWriter(write_path)
md_writer.write_string('19983-00.md', md_content) md_writer.write_string('19983-00.md', md_content)
md_writer.write_string('19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)) md_writer.write_string(
'19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
)
md_writer.write_string('19983-00.txt', str(content_list)) md_writer.write_string('19983-00.txt', str(content_list))
import json
import os
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
draw_span_bbox)
from magic_pdf.libs.json_compressor import JsonCompressor
class PipeResult:
    """Wrap the result of a parsing pipeline together with its source dataset.

    The wrapped result is expected to be a mapping with a ``'pdf_info'`` entry;
    every dump/draw method below reads that entry.
    """

    def __init__(self, pipe_res, dataset: Dataset):
        """Store the pipeline result and the dataset it was computed from.

        Args:
            pipe_res: the pipeline result; must support ``pipe_res['pdf_info']``
                and be serialisable by ``json.dumps``.
            dataset (Dataset): the dataset whose ``data_bits()`` are handed to
                the bbox drawing helpers.
        """
        self._pipe_res = pipe_res
        self._dataset = dataset

    @staticmethod
    def _prepare_output_path(file_path: str):
        """Split ``file_path`` into (directory, file name) and create the directory.

        A bare file name has an empty ``dirname``; ``os.makedirs('')`` raises
        ``FileNotFoundError``, so the current directory is used instead.
        """
        dir_name = os.path.dirname(file_path) or os.curdir
        base_name = os.path.basename(file_path)
        os.makedirs(dir_name, exist_ok=True)
        return dir_name, base_name

    def _content_list_json(self, img_dir_or_bucket_prefix: str, drop_mode) -> str:
        """Build the ``MakeMode.STANDARD_FORMAT`` content list and serialise it to JSON."""
        content_list = union_make(
            self._pipe_res['pdf_info'],
            MakeMode.STANDARD_FORMAT,
            drop_mode,
            img_dir_or_bucket_prefix,
        )
        return json.dumps(content_list, ensure_ascii=False, indent=4)

    def dump_md(
        self,
        writer: DataWriter,
        file_path: str,
        img_dir_or_bucket_prefix: str,
        drop_mode=DropMode.WHOLE_PDF,
        md_make_mode=MakeMode.MM_MD,
    ) -> None:
        """Render the result with ``union_make`` and write it through ``writer``.

        Args:
            writer (DataWriter): destination writer.
            file_path (str): path passed to ``writer.write_string``.
            img_dir_or_bucket_prefix (str): image location prefix forwarded to
                ``union_make``.
            drop_mode: forwarded to ``union_make``. Defaults to DropMode.WHOLE_PDF.
            md_make_mode: forwarded to ``union_make``. Defaults to MakeMode.MM_MD.
        """
        md_content = union_make(
            self._pipe_res['pdf_info'],
            md_make_mode,
            drop_mode,
            img_dir_or_bucket_prefix,
        )
        writer.write_string(file_path, md_content)

    def dump_content_list(
        self,
        writer: DataWriter,
        file_path: str,
        image_dir_or_bucket_prefix: str,
        drop_mode=DropMode.NONE,
    ) -> None:
        """Write the standard-format content list as JSON through ``writer``.

        Args:
            writer (DataWriter): destination writer.
            file_path (str): path passed to ``writer.write_string``.
            image_dir_or_bucket_prefix (str): image location prefix forwarded
                to ``union_make``.
            drop_mode: forwarded to ``union_make``. Defaults to DropMode.NONE.
        """
        writer.write_string(
            file_path, self._content_list_json(image_dir_or_bucket_prefix, drop_mode)
        )

    def dump_middle_json(self, writer: DataWriter, file_path: str) -> None:
        """Write the whole pipeline result as indented JSON through ``writer``.

        Args:
            writer (DataWriter): destination writer.
            file_path (str): path passed to ``writer.write_string``.
        """
        writer.write_string(
            file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
        )

    def draw_layout(self, file_path: str) -> None:
        """Draw the layout bboxes via ``draw_layout_bbox``.

        Args:
            file_path (str): output file; its directory is created if missing.
        """
        dir_name, base_name = self._prepare_output_path(file_path)
        pdf_info = self._pipe_res['pdf_info']
        draw_layout_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)

    def draw_span(self, file_path: str):
        """Draw the span bboxes via ``draw_span_bbox``.

        Args:
            file_path (str): output file; its directory is created if missing.
        """
        dir_name, base_name = self._prepare_output_path(file_path)
        pdf_info = self._pipe_res['pdf_info']
        draw_span_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)

    def draw_line_sort(self, file_path: str):
        """Draw the line-sort bboxes via ``draw_line_sort_bbox``.

        Args:
            file_path (str): output file; its directory is created if missing.
        """
        dir_name, base_name = self._prepare_output_path(file_path)
        pdf_info = self._pipe_res['pdf_info']
        draw_line_sort_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)

    def draw_content_list(
        self,
        writer: DataWriter,
        file_path: str,
        img_dir_or_bucket_prefix: str,
        drop_mode=DropMode.WHOLE_PDF,
    ):
        """Write the standard-format content list as JSON through ``writer``.

        Same output as ``dump_content_list``; only the default ``drop_mode``
        (DropMode.WHOLE_PDF here) and the prefix parameter name differ.

        Args:
            writer (DataWriter): destination writer.
            file_path (str): path passed to ``writer.write_string``.
            img_dir_or_bucket_prefix (str): image location prefix forwarded to
                ``union_make``.
            drop_mode: forwarded to ``union_make``. Defaults to DropMode.WHOLE_PDF.
        """
        writer.write_string(
            file_path, self._content_list_json(img_dir_or_bucket_prefix, drop_mode)
        )

    def get_compress_pdf_mid_data(self):
        """Return the pipeline result compressed by ``JsonCompressor.compress_json``."""
        # The result is stored as ``_pipe_res``; this class never sets a
        # ``pdf_mid_data`` attribute, so reading it raised AttributeError.
        return JsonCompressor.compress_json(self._pipe_res)
import copy
import json as json_parse
import os import os
import click import click
...@@ -7,13 +5,12 @@ import fitz ...@@ -7,13 +5,12 @@ import fitz
from loguru import logger from loguru import logger
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox, from magic_pdf.data.dataset import PymuDocDataset
draw_model_bbox, draw_span_bbox) from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.model.types import InferenceResult
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.pipe.UNIPipe import UNIPipe
# from io import BytesIO # from io import BytesIO
# from pypdf import PdfReader, PdfWriter # from pypdf import PdfReader, PdfWriter
...@@ -56,7 +53,11 @@ def prepare_env(output_dir, pdf_file_name, method): ...@@ -56,7 +53,11 @@ def prepare_env(output_dir, pdf_file_name, method):
def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None): def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
document = fitz.open('pdf', pdf_bytes) document = fitz.open('pdf', pdf_bytes)
output_document = fitz.open() output_document = fitz.open()
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1 end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(document) - 1
)
if end_page_id > len(document) - 1: if end_page_id > len(document) - 1:
logger.warning('end_page_id is out of range, use pdf_docs length') logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(document) - 1 end_page_id = len(document) - 1
...@@ -94,78 +95,123 @@ def do_parse( ...@@ -94,78 +95,123 @@ def do_parse(
f_draw_model_bbox = True f_draw_model_bbox = True
f_draw_line_sort_bbox = True f_draw_line_sort_bbox = True
if lang == "": if lang == '':
lang = None lang = None
pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id, end_page_id) pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
pdf_bytes, start_page_id, end_page_id
)
orig_model_list = copy.deepcopy(model_list) local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
parse_method)
image_writer, md_writer = FileBasedDataWriter( image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
local_image_dir), FileBasedDataWriter(local_md_dir) local_md_dir
)
image_dir = str(os.path.basename(local_image_dir)) image_dir = str(os.path.basename(local_image_dir))
ds = PymuDocDataset(pdf_bytes)
if len(model_list) == 0:
if model_config.__use_inside_model__:
if parse_method == 'auto': if parse_method == 'auto':
jso_useful_key = {'_pdf_type': '', 'model_list': model_list} if ds.classify() == SupportedPdfParseMethod.TXT:
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True, infer_result = ds.apply(
# start_page_id=start_page_id, end_page_id=end_page_id, doc_analyze,
ocr=False,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
else:
infer_result = ds.apply(
doc_analyze,
ocr=True,
lang=lang, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable) layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_auto_mode(
image_writer, debug_mode=True, lang=lang
)
elif parse_method == 'txt': elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True, infer_result = ds.apply(
# start_page_id=start_page_id, end_page_id=end_page_id, doc_analyze,
ocr=False,
lang=lang, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable) layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
)
elif parse_method == 'ocr': elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True, infer_result = ds.apply(
# start_page_id=start_page_id, end_page_id=end_page_id, doc_analyze,
ocr=True,
lang=lang, lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable) layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
)
else: else:
logger.error('unknown parse method') logger.error('unknown parse method')
exit(1) exit(1)
pipe.pipe_classify()
if len(model_list) == 0:
if model_config.__use_inside_model__:
pipe.pipe_analyze()
orig_model_list = copy.deepcopy(pipe.model_list)
else: else:
logger.error('need model list input') logger.error('need model list input')
exit(2) exit(2)
else:
infer_result = InferenceResult(model_list, ds)
if parse_method == 'ocr':
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
)
elif parse_method == 'txt':
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
)
else:
pipe_result = infer_result.pipe_auto_mode(
image_writer, debug_mode=True, lang=lang
)
if f_draw_model_bbox:
infer_result.draw_model(
os.path.join(local_md_dir, f'{pdf_file_name}_model.pdf')
)
pipe.pipe_parse()
pdf_info = pipe.pdf_mid_data['pdf_info']
if f_draw_layout_bbox: if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) pipe_result.draw_layout(
os.path.join(local_md_dir, f'{pdf_file_name}_layout.pdf')
)
if f_draw_span_bbox: if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) pipe_result.draw_span(os.path.join(local_md_dir, f'{pdf_file_name}_spans.pdf'))
if f_draw_model_bbox:
draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
if f_draw_line_sort_bbox: if f_draw_line_sort_bbox:
draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) pipe_result.draw_line_sort(
os.path.join(local_md_dir, f'{pdf_file_name}_line_sort.pdf')
)
md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
if f_dump_md: if f_dump_md:
md_writer.write_string( pipe_result.dump_md(
md_writer,
f'{pdf_file_name}.md', f'{pdf_file_name}.md',
md_content image_dir,
drop_mode=DropMode.NONE,
md_make_mode=f_make_md_mode,
) )
if f_dump_middle_json: if f_dump_middle_json:
md_writer.write_string( pipe_result.dump_middle_json(md_writer, f'{pdf_file_name}_middle.json')
f'{pdf_file_name}_middle.json',
json_parse.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
)
if f_dump_model_json: if f_dump_model_json:
md_writer.write_string( infer_result.dump_model(md_writer, f'{pdf_file_name}_model.json')
f'{pdf_file_name}_model.json',
json_parse.dumps(orig_model_list, ensure_ascii=False, indent=4)
)
if f_dump_orig_pdf: if f_dump_orig_pdf:
md_writer.write( md_writer.write(
...@@ -173,11 +219,12 @@ def do_parse( ...@@ -173,11 +219,12 @@ def do_parse(
pdf_bytes, pdf_bytes,
) )
content_list = pipe.pipe_mk_uni_format(image_dir, drop_mode=DropMode.NONE)
if f_dump_content_list: if f_dump_content_list:
md_writer.write_string( pipe_result.dump_content_list(
md_writer,
f'{pdf_file_name}_content_list.json', f'{pdf_file_name}_content_list.json',
json_parse.dumps(content_list, ensure_ascii=False, indent=4) image_dir,
drop_mode=DropMode.NONE,
) )
logger.info(f'local output dir is {local_md_dir}') logger.info(f'local output dir is {local_md_dir}')
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
from loguru import logger from loguru import logger
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.version import __version__ from magic_pdf.libs.version import __version__
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
...@@ -19,13 +20,21 @@ PARSE_TYPE_TXT = 'txt' ...@@ -19,13 +20,21 @@ PARSE_TYPE_TXT = 'txt'
PARSE_TYPE_OCR = 'ocr' PARSE_TYPE_OCR = 'ocr'
def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False, def parse_txt_pdf(
start_page_id=0, end_page_id=None, lang=None, dataset: Dataset,
*args, **kwargs): model_list: list,
imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""解析文本类pdf.""" """解析文本类pdf."""
pdf_info_dict = parse_pdf_by_txt( pdf_info_dict = parse_pdf_by_txt(
pdf_bytes, dataset,
pdf_models, model_list,
imageWriter, imageWriter,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -43,13 +52,21 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i ...@@ -43,13 +52,21 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
return pdf_info_dict return pdf_info_dict
def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False, def parse_ocr_pdf(
start_page_id=0, end_page_id=None, lang=None, dataset: Dataset,
*args, **kwargs): model_list: list,
imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""解析ocr类pdf.""" """解析ocr类pdf."""
pdf_info_dict = parse_pdf_by_ocr( pdf_info_dict = parse_pdf_by_ocr(
pdf_bytes, dataset,
pdf_models, model_list,
imageWriter, imageWriter,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -67,17 +84,24 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i ...@@ -67,17 +84,24 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
return pdf_info_dict return pdf_info_dict
def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False, def parse_union_pdf(
input_model_is_empty: bool = False, dataset: Dataset,
start_page_id=0, end_page_id=None, lang=None, model_list: list,
*args, **kwargs): imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""ocr和文本混合的pdf,全部解析出来.""" """ocr和文本混合的pdf,全部解析出来."""
def parse_pdf(method): def parse_pdf(method):
try: try:
return method( return method(
pdf_bytes, dataset,
pdf_models, model_list,
imageWriter, imageWriter,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -91,12 +115,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, ...@@ -91,12 +115,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
pdf_info_dict = parse_pdf(parse_pdf_by_txt) pdf_info_dict = parse_pdf(parse_pdf_by_txt)
if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False): if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr') logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
if input_model_is_empty: if len(model_list) == 0:
layout_model = kwargs.get('layout_model', None) layout_model = kwargs.get('layout_model', None)
formula_enable = kwargs.get('formula_enable', None) formula_enable = kwargs.get('formula_enable', None)
table_enable = kwargs.get('table_enable', None) table_enable = kwargs.get('table_enable', None)
pdf_models = doc_analyze( infer_res = doc_analyze(
pdf_bytes, dataset,
ocr=True, ocr=True,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -105,6 +129,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, ...@@ -105,6 +129,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
formula_enable=formula_enable, formula_enable=formula_enable,
table_enable=table_enable, table_enable=table_enable,
) )
model_list = infer_res.get_infer_res()
pdf_info_dict = parse_pdf(parse_pdf_by_ocr) pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None: if pdf_info_dict is None:
raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.') raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment