"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "6a992437915f03ae1dc84e00e6b6f735429661ad"
Unverified Commit fa113b57 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1178 from icecraft/refactor/add_user_api

Refactor/add user api
parents 1c10dc55 e4ed6023
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Iterator from typing import Callable, Iterator
import fitz import fitz
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image from magic_pdf.data.utils import fitz_doc_to_image
from magic_pdf.filter import classify
class PageableData(ABC): class PageableData(ABC):
...@@ -28,6 +30,32 @@ class PageableData(ABC): ...@@ -28,6 +30,32 @@ class PageableData(ABC):
""" """
pass pass
@abstractmethod
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
"""draw rectangle.
Args:
rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
fill (list[float] | None): fill the board with RGB, None means will not fill with color
fill_opacity (float): opacity of the fill, range from [0, 1]
width (float): the width of board
overlay (bool): fill the color in foreground or background. True means fill in background.
"""
pass
@abstractmethod
def insert_text(self, coord, content, fontsize, color):
"""insert text.
Args:
coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
content (str): the text content
fontsize (int): font size of the text
color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
"""
pass
class Dataset(ABC): class Dataset(ABC):
@abstractmethod @abstractmethod
...@@ -66,6 +94,43 @@ class Dataset(ABC): ...@@ -66,6 +94,43 @@ class Dataset(ABC):
""" """
pass pass
@abstractmethod
def dump_to_file(self, file_path: str):
"""Dump the file
Args:
file_path (str): the file path
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(self, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@abstractmethod
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
Returns:
SupportedPdfParseMethod: _description_
"""
pass
@abstractmethod
def clone(self):
"""clone this dataset
"""
pass
class PymuDocDataset(Dataset): class PymuDocDataset(Dataset):
def __init__(self, bits: bytes): def __init__(self, bits: bytes):
...@@ -74,7 +139,8 @@ class PymuDocDataset(Dataset): ...@@ -74,7 +139,8 @@ class PymuDocDataset(Dataset):
Args: Args:
bits (bytes): the bytes of the pdf bits (bytes): the bytes of the pdf
""" """
self._records = [Doc(v) for v in fitz.open('pdf', bits)] self._raw_fitz = fitz.open('pdf', bits)
self._records = [Doc(v) for v in self._raw_fitz]
self._data_bits = bits self._data_bits = bits
self._raw_data = bits self._raw_data = bits
...@@ -109,6 +175,43 @@ class PymuDocDataset(Dataset): ...@@ -109,6 +175,43 @@ class PymuDocDataset(Dataset):
""" """
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True)
self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(dataset, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
Returns:
SupportedPdfParseMethod: _description_
"""
return classify(self._data_bits)
def clone(self):
"""clone this dataset
"""
return PymuDocDataset(self._raw_data)
class ImageDataset(Dataset): class ImageDataset(Dataset):
def __init__(self, bits: bytes): def __init__(self, bits: bytes):
...@@ -118,7 +221,8 @@ class ImageDataset(Dataset): ...@@ -118,7 +221,8 @@ class ImageDataset(Dataset):
bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc. bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
""" """
pdf_bytes = fitz.open(stream=bits).convert_to_pdf() pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)] self._raw_fitz = fitz.open('pdf', pdf_bytes)
self._records = [Doc(v) for v in self._raw_fitz]
self._raw_data = bits self._raw_data = bits
self._data_bits = pdf_bytes self._data_bits = pdf_bytes
...@@ -153,14 +257,50 @@ class ImageDataset(Dataset): ...@@ -153,14 +257,50 @@ class ImageDataset(Dataset):
""" """
return self._records[page_id] return self._records[page_id]
def dump_to_file(self, file_path: str):
"""Dump the file
Args:
file_path (str): the file path
"""
dir_name = os.path.dirname(file_path)
if dir_name not in ('', '.', '..'):
os.makedirs(dir_name, exist_ok=True)
self._raw_fitz.save(file_path)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(dataset, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(self, *args, **kwargs)
def classify(self) -> SupportedPdfParseMethod:
"""classify the dataset
Returns:
SupportedPdfParseMethod: _description_
"""
return SupportedPdfParseMethod.OCR
def clone(self):
"""clone this dataset
"""
return ImageDataset(self._raw_data)
class Doc(PageableData): class Doc(PageableData):
"""Initialized with pymudoc object.""" """Initialized with pymudoc object."""
def __init__(self, doc: fitz.Page): def __init__(self, doc: fitz.Page):
self._doc = doc self._doc = doc
def get_image(self): def get_image(self):
"""Return the imge info. """Return the image info.
Returns: Returns:
dict: { dict: {
...@@ -192,3 +332,34 @@ class Doc(PageableData): ...@@ -192,3 +332,34 @@ class Doc(PageableData):
def __getattr__(self, name): def __getattr__(self, name):
if hasattr(self._doc, name): if hasattr(self._doc, name):
return getattr(self._doc, name) return getattr(self._doc, name)
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
"""draw rectangle.
Args:
rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
fill (list[float] | None): fill the board with RGB, None means will not fill with color
fill_opacity (float): opacity of the fill, range from [0, 1]
width (float): the width of board
overlay (bool): fill the color in foreground or background. True means fill in background.
"""
self._doc.draw_rect(
rect_coords,
color=color,
fill=fill,
fill_opacity=fill_opacity,
width=width,
overlay=overlay,
)
def insert_text(self, coord, content, fontsize, color):
"""insert text.
Args:
coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
content (str): the text content
fontsize (int): font size of the text
color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
"""
self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.filter.pdf_classify_by_type import classify as do_classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
"""根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
pdf_meta = pdf_meta_scan(pdf_bytes)
if pdf_meta.get('_need_drop', False): # 如果返回了需要丢弃的标志,则抛出异常
raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
else:
is_encrypted = pdf_meta['is_encrypted']
is_needs_password = pdf_meta['is_needs_password']
if is_encrypted or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理
raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
else:
is_text_pdf, results = do_classify(
pdf_meta['total_page'],
pdf_meta['page_width_pts'],
pdf_meta['page_height_pts'],
pdf_meta['image_info_per_page'],
pdf_meta['text_len_per_page'],
pdf_meta['imgs_per_page'],
pdf_meta['text_layout_per_page'],
pdf_meta['invalid_chars'],
)
if is_text_pdf:
return SupportedPdfParseMethod.TXT
else:
return SupportedPdfParseMethod.OCR
import fitz import fitz
from magic_pdf.config.constants import CROSS_PAGE from magic_pdf.config.constants import CROSS_PAGE
from magic_pdf.config.ocr_content_type import BlockType, CategoryId, ContentType from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
from magic_pdf.data.dataset import PymuDocDataset ContentType)
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.magic_model import MagicModel from magic_pdf.model.magic_model import MagicModel
...@@ -194,7 +195,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -194,7 +195,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
) )
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_layout.pdf') pdf_docs.save(f'{out_path}/{filename}')
def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
...@@ -282,18 +283,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -282,18 +283,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False) draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_spans.pdf') pdf_docs.save(f'{out_path}/{filename}')
def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
dropped_bbox_list = [] dropped_bbox_list = []
tables_body_list, tables_caption_list, tables_footnote_list = [], [], [] tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], [] imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
titles_list = [] titles_list = []
texts_list = [] texts_list = []
interequations_list = [] interequations_list = []
pdf_docs = fitz.open('pdf', pdf_bytes) magic_model = MagicModel(model_list, dataset)
magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(len(model_list)): for i in range(len(model_list)):
page_dropped_list = [] page_dropped_list = []
tables_body, tables_caption, tables_footnote = [], [], [] tables_body, tables_caption, tables_footnote = [], [], []
...@@ -337,7 +337,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -337,7 +337,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
dropped_bbox_list.append(page_dropped_list) dropped_bbox_list.append(page_dropped_list)
imgs_footnote_list.append(imgs_footnote) imgs_footnote_list.append(imgs_footnote)
for i, page in enumerate(pdf_docs): for i in range(len(dataset)):
page = dataset.get_page(i)
draw_bbox_with_number( draw_bbox_with_number(
i, dropped_bbox_list, page, [158, 158, 158], True i, dropped_bbox_list, page, [158, 158, 158], True
) # color ! ) # color !
...@@ -352,7 +353,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename): ...@@ -352,7 +353,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True) draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
# Save the PDF # Save the PDF
pdf_docs.save(f'{out_path}/{filename}_model.pdf') dataset.dump_to_file(f'{out_path}/{filename}')
def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename): def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
......
from typing import Callable
from abc import ABC, abstractmethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.pipe.operators import PipeResult
__use_inside_model__ = True __use_inside_model__ = True
__model_mode__ = "full" __model_mode__ = "full"
class InferenceResultBase(ABC):
@abstractmethod
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
self._infer_res = inference_results
self._dataset = dataset
@abstractmethod
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
pass
@abstractmethod
def dump_model(self, writer: DataWriter, file_path: str):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
pass
@abstractmethod
def get_infer_res(self):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
pass
@abstractmethod
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
pass
@abstractmethod
def pipe_auto_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_txt_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pass
@abstractmethod
def pipe_ocr_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
pass
import time import time
import fitz import fitz
import numpy as np import numpy as np
from loguru import logger from loguru import logger
import magic_pdf.model as model_config
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_formula_config get_layout_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
import magic_pdf.model as model_config from magic_pdf.model.operators import InferenceResult
def dict_compare(d1, d2): def dict_compare(d1, d2):
...@@ -19,25 +24,31 @@ def remove_duplicates_dicts(lst): ...@@ -19,25 +24,31 @@ def remove_duplicates_dicts(lst):
unique_dicts = [] unique_dicts = []
for dict_item in lst: for dict_item in lst:
if not any( if not any(
dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
): ):
unique_dicts.append(dict_item) unique_dicts.append(dict_item)
return unique_dicts return unique_dicts
def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list: def load_images_from_pdf(
pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
) -> list:
try: try:
from PIL import Image from PIL import Image
except ImportError: except ImportError:
logger.error("Pillow not installed, please install by pip.") logger.error('Pillow not installed, please install by pip.')
exit(1) exit(1)
images = [] images = []
with fitz.open("pdf", pdf_bytes) as doc: with fitz.open('pdf', pdf_bytes) as doc:
pdf_page_num = doc.page_count pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1 end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1: if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length") logger.warning('end_page_id is out of range, use images length')
end_page_id = pdf_page_num - 1 end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count): for index in range(0, doc.page_count):
...@@ -50,11 +61,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id ...@@ -50,11 +61,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
if pm.width > 4500 or pm.height > 4500: if pm.width > 4500 or pm.height > 4500:
pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False) pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples) img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
img = np.array(img) img = np.array(img)
img_dict = {"img": img, "width": pm.width, "height": pm.height} img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
else: else:
img_dict = {"img": [], "width": 0, "height": 0} img_dict = {'img': [], 'width': 0, 'height': 0}
images.append(img_dict) images.append(img_dict)
return images return images
...@@ -69,117 +80,150 @@ class ModelSingleton: ...@@ -69,117 +80,150 @@ class ModelSingleton:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cls._instance
def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None): def get_model(
self,
ocr: bool,
show_log: bool,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
key = (ocr, show_log, lang, layout_model, formula_enable, table_enable) key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
if key not in self._models: if key not in self._models:
self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model, self._models[key] = custom_model_init(
formula_enable=formula_enable, table_enable=table_enable) ocr=ocr,
show_log=show_log,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
return self._models[key] return self._models[key]
def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None, def custom_model_init(
layout_model=None, formula_enable=None, table_enable=None): ocr: bool = False,
show_log: bool = False,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
model = None model = None
if model_config.__model_mode__ == "lite": if model_config.__model_mode__ == 'lite':
logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is " logger.warning(
"not guaranteed to be reliable.") 'The Lite mode is provided for developers to conduct testing only, and the output quality is '
'not guaranteed to be reliable.'
)
model = MODEL.Paddle model = MODEL.Paddle
elif model_config.__model_mode__ == "full": elif model_config.__model_mode__ == 'full':
model = MODEL.PEK model = MODEL.PEK
if model_config.__use_inside_model__: if model_config.__use_inside_model__:
model_init_start = time.time() model_init_start = time.time()
if model == MODEL.Paddle: if model == MODEL.Paddle:
from magic_pdf.model.pp_structure_v2 import CustomPaddleModel from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang) custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
elif model == MODEL.PEK: elif model == MODEL.PEK:
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
# 从配置文件读取model-dir和device # 从配置文件读取model-dir和device
local_models_dir = get_local_models_dir() local_models_dir = get_local_models_dir()
device = get_device() device = get_device()
layout_config = get_layout_config() layout_config = get_layout_config()
if layout_model is not None: if layout_model is not None:
layout_config["model"] = layout_model layout_config['model'] = layout_model
formula_config = get_formula_config() formula_config = get_formula_config()
if formula_enable is not None: if formula_enable is not None:
formula_config["enable"] = formula_enable formula_config['enable'] = formula_enable
table_config = get_table_recog_config() table_config = get_table_recog_config()
if table_enable is not None: if table_enable is not None:
table_config["enable"] = table_enable table_config['enable'] = table_enable
model_input = { model_input = {
"ocr": ocr, 'ocr': ocr,
"show_log": show_log, 'show_log': show_log,
"models_dir": local_models_dir, 'models_dir': local_models_dir,
"device": device, 'device': device,
"table_config": table_config, 'table_config': table_config,
"layout_config": layout_config, 'layout_config': layout_config,
"formula_config": formula_config, 'formula_config': formula_config,
"lang": lang, 'lang': lang,
} }
custom_model = CustomPEKModel(**model_input) custom_model = CustomPEKModel(**model_input)
else: else:
logger.error("Not allow model_name!") logger.error('Not allow model_name!')
exit(1) exit(1)
model_init_cost = time.time() - model_init_start model_init_cost = time.time() - model_init_start
logger.info(f"model init cost: {model_init_cost}") logger.info(f'model init cost: {model_init_cost}')
else: else:
logger.error("use_inside_model is False, not allow to use inside model") logger.error('use_inside_model is False, not allow to use inside model')
exit(1) exit(1)
return custom_model return custom_model
def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, def doc_analyze(
start_page_id=0, end_page_id=None, lang=None, dataset: Dataset,
layout_model=None, formula_enable=None, table_enable=None): ocr: bool = False,
show_log: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
) -> InferenceResult:
if lang == "": if lang == '':
lang = None lang = None
model_manager = ModelSingleton() model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable) custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
with fitz.open("pdf", pdf_bytes) as doc: )
pdf_page_num = doc.page_count
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
if end_page_id > pdf_page_num - 1:
logger.warning("end_page_id is out of range, use images length")
end_page_id = pdf_page_num - 1
images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
model_json = [] model_json = []
doc_analyze_start = time.time() doc_analyze_start = time.time()
for index, img_dict in enumerate(images): if end_page_id is None:
img = img_dict["img"] end_page_id = len(dataset)
page_width = img_dict["width"]
page_height = img_dict["height"] for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id: if start_page_id <= index <= end_page_id:
page_start = time.time() page_start = time.time()
result = custom_model(img) result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----') logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else: else:
result = [] result = []
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info} page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict) model_json.append(page_dict)
gc_start = time.time() gc_start = time.time()
clean_memory() clean_memory()
gc_time = round(time.time() - gc_start, 2) gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}") logger.info(f'gc time: {gc_time}')
doc_analyze_time = round(time.time() - doc_analyze_start, 2) doc_analyze_time = round(time.time() - doc_analyze_start, 2)
doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2) doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)}," logger.info(
f" speed: {doc_analyze_speed} pages/second") f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
f' speed: {doc_analyze_speed} pages/second'
)
return model_json return InferenceResult(model_json, dataset)
import copy
import json
import os
from typing import Callable
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.filter import classify
from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.operators import PipeResult
from magic_pdf.model import InferenceResultBase
class InferenceResult(InferenceResultBase):
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
Args:
inference_results (list): the inference result generated by model
dataset (Dataset): the dataset related with model inference result
"""
self._infer_res = inference_results
self._dataset = dataset
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
draw_model_bbox(
copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
)
def dump_model(self, writer: DataWriter, file_path: str):
"""Dump model inference result to file.
Args:
writer (DataWriter): writer handle
file_path (str): the location of target file
"""
writer.write_string(
file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
)
def get_infer_res(self):
"""Get the inference result.
Returns:
list: the inference result generated by model
"""
return self._infer_res
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(inference_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
def pipe_auto_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pdf_proc_method = classify(self._dataset.data_bits())
if pdf_proc_method == SupportedPdfParseMethod.TXT:
return self.pipe_txt_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
else:
return self.pipe_ocr_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
def pipe_txt_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
def proc(*args, **kwargs) -> PipeResult:
res = pdf_parse_union(*args, **kwargs)
return PipeResult(res, self._dataset)
return self.apply(
proc,
self._dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
def pipe_ocr_mode(
self,
imageWriter: DataWriter,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using `OCR`
technical.
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (str, optional): Defaults to None.
Returns:
PipeResult: the result
"""
def proc(*args, **kwargs) -> PipeResult:
res = pdf_parse_union(*args, **kwargs)
return PipeResult(res, self._dataset)
return self.apply(
proc,
self._dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_ocr(pdf_bytes, def parse_pdf_by_ocr(dataset: Dataset,
model_list, model_list,
imageWriter, imageWriter,
start_page_id=0, start_page_id=0,
...@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes, ...@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes,
debug_mode=False, debug_mode=False,
lang=None, lang=None,
): ):
dataset = PymuDocDataset(pdf_bytes) return pdf_parse_union(model_list,
return pdf_parse_union(dataset, dataset,
model_list,
imageWriter, imageWriter,
SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.OCR,
start_page_id=start_page_id, start_page_id=start_page_id,
......
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import Dataset
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
def parse_pdf_by_txt( def parse_pdf_by_txt(
pdf_bytes, dataset: Dataset,
model_list, model_list,
imageWriter, imageWriter,
start_page_id=0, start_page_id=0,
...@@ -12,9 +12,8 @@ def parse_pdf_by_txt( ...@@ -12,9 +12,8 @@ def parse_pdf_by_txt(
debug_mode=False, debug_mode=False,
lang=None, lang=None,
): ):
dataset = PymuDocDataset(pdf_bytes) return pdf_parse_union(model_list,
return pdf_parse_union(dataset, dataset,
model_list,
imageWriter, imageWriter,
SupportedPdfParseMethod.TXT, SupportedPdfParseMethod.TXT,
start_page_id=start_page_id, start_page_id=start_page_id,
......
...@@ -4,8 +4,8 @@ import statistics ...@@ -4,8 +4,8 @@ import statistics
import time import time
from typing import List from typing import List
import torch
import fitz import fitz
import torch
from loguru import logger from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
...@@ -16,17 +16,13 @@ from magic_pdf.libs.clean_memory import clean_memory ...@@ -16,17 +16,13 @@ from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5 from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.model.magic_model import MagicModel from magic_pdf.model.magic_model import MagicModel
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try: try:
import torchtext import torchtext
if torchtext.__version__ >= "0.18.0": if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning() torchtext.disable_torchtext_deprecation_warning()
except ImportError: except ImportError:
pass pass
...@@ -39,6 +35,9 @@ from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layo ...@@ -39,6 +35,9 @@ from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layo
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, remove_overlaps_min_spans from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, remove_overlaps_min_spans
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
def __replace_STX_ETX(text_str: str): def __replace_STX_ETX(text_str: str):
"""Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks. """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
...@@ -241,7 +240,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang ...@@ -241,7 +240,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
for span in empty_spans: for span in empty_spans:
# 对span的bbox截图再ocr # 对span的bbox截图再ocr
span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode="cv2") span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
ocr_res = ocr_model.ocr(span_img, det=False) ocr_res = ocr_model.ocr(span_img, det=False)
if ocr_res and len(ocr_res) > 0: if ocr_res and len(ocr_res) > 0:
if len(ocr_res[0]) > 0: if len(ocr_res[0]) > 0:
...@@ -681,7 +680,7 @@ def parse_page_core( ...@@ -681,7 +680,7 @@ def parse_page_core(
"""根据parse_mode,构造spans,主要是文本类的字符填充""" """根据parse_mode,构造spans,主要是文本类的字符填充"""
if parse_mode == SupportedPdfParseMethod.TXT: if parse_mode == SupportedPdfParseMethod.TXT:
"""使用新版本的混合ocr方案""" """使用新版本的混合ocr方案."""
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang) spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
elif parse_mode == SupportedPdfParseMethod.OCR: elif parse_mode == SupportedPdfParseMethod.OCR:
...@@ -689,7 +688,6 @@ def parse_page_core( ...@@ -689,7 +688,6 @@ def parse_page_core(
else: else:
raise Exception('parse_mode must be txt or ocr') raise Exception('parse_mode must be txt or ocr')
"""先处理不需要排版的discarded_blocks""" """先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks( discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4 all_discarded_blocks, spans, 0.4
...@@ -762,8 +760,8 @@ def parse_page_core( ...@@ -762,8 +760,8 @@ def parse_page_core(
def pdf_parse_union( def pdf_parse_union(
dataset: Dataset,
model_list, model_list,
dataset: Dataset,
imageWriter, imageWriter,
parse_mode, parse_mode,
start_page_id=0, start_page_id=0,
......
...@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod ...@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from magic_pdf.config.drop_reason import DropReason from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.dict2md.ocr_mkcontent import union_make from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.filter.pdf_classify_by_type import classify from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
...@@ -14,9 +15,9 @@ class AbsPipe(ABC): ...@@ -14,9 +15,9 @@ class AbsPipe(ABC):
PIP_OCR = 'ocr' PIP_OCR = 'ocr'
PIP_TXT = 'txt' PIP_TXT = 'txt'
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False, def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None): start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.pdf_bytes = pdf_bytes self.dataset = Dataset
self.model_list = model_list self.model_list = model_list
self.image_writer = image_writer self.image_writer = image_writer
self.pdf_mid_data = None # 未压缩 self.pdf_mid_data = None # 未压缩
......
...@@ -2,40 +2,79 @@ from loguru import logger ...@@ -2,40 +2,79 @@ from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe): class OCRPipe(AbsPipe):
def __init__(
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False, self,
start_page_id=0, end_page_id=None, lang=None, dataset: Dataset,
layout_model=None, formula_enable=None, table_enable=None): model_list: list,
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang, image_writer: DataWriter,
layout_model, formula_enable, table_enable) is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
super().__init__(
dataset,
model_list,
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.infer_res = doc_analyze(
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.dataset,
lang=self.lang, layout_model=self.layout_model, ocr=True,
formula_enable=self.formula_enable, table_enable=self.table_enable) start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_ocr_pdf(
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.dataset,
lang=self.lang, layout_model=self.layout_model, self.infer_res,
formula_enable=self.formula_enable, table_enable=self.table_enable) self.image_writer,
is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF): def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('ocr_pipe mk content list finished') logger.info('ocr_pipe mk content list finished')
return result return result
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD): def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode) result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'ocr_pipe mk {md_make_mode} finished') logger.info(f'ocr_pipe mk {md_make_mode} finished')
return result return result
...@@ -2,6 +2,7 @@ from loguru import logger ...@@ -2,6 +2,7 @@ from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_txt_pdf from magic_pdf.user_api import parse_txt_pdf
...@@ -9,23 +10,23 @@ from magic_pdf.user_api import parse_txt_pdf ...@@ -9,23 +10,23 @@ from magic_pdf.user_api import parse_txt_pdf
class TXTPipe(AbsPipe): class TXTPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False, def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None, start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None): layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang, super().__init__(dataset, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
layout_model, formula_enable, table_enable) layout_model, formula_enable, table_enable)
def pipe_classify(self): def pipe_classify(self):
pass pass
def pipe_analyze(self): def pipe_analyze(self):
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(self.dataset, ocr=False,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model, lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable) formula_enable=self.formula_enable, table_enable=self.table_enable)
def pipe_parse(self): def pipe_parse(self):
self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug, self.pdf_mid_data = parse_txt_pdf(self.dataset, self.model_list, self.image_writer, is_debug=self.is_debug,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, start_page_id=self.start_page_id, end_page_id=self.end_page_id,
lang=self.lang, layout_model=self.layout_model, lang=self.lang, layout_model=self.layout_model,
formula_enable=self.formula_enable, table_enable=self.table_enable) formula_enable=self.formula_enable, table_enable=self.table_enable)
......
...@@ -4,6 +4,7 @@ from loguru import logger ...@@ -4,6 +4,7 @@ from loguru import logger
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.commons import join_path from magic_pdf.libs.commons import join_path
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe from magic_pdf.pipe.AbsPipe import AbsPipe
...@@ -12,12 +13,32 @@ from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf ...@@ -12,12 +13,32 @@ from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
class UNIPipe(AbsPipe): class UNIPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: DataWriter, is_debug: bool = False, def __init__(
start_page_id=0, end_page_id=None, lang=None, self,
layout_model=None, formula_enable=None, table_enable=None): dataset: Dataset,
jso_useful_key: dict,
image_writer: DataWriter,
is_debug: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
):
self.pdf_type = jso_useful_key['_pdf_type'] self.pdf_type = jso_useful_key['_pdf_type']
super().__init__(pdf_bytes, jso_useful_key['model_list'], image_writer, is_debug, start_page_id, end_page_id, super().__init__(
lang, layout_model, formula_enable, table_enable) dataset,
jso_useful_key['model_list'],
image_writer,
is_debug,
start_page_id,
end_page_id,
lang,
layout_model,
formula_enable,
table_enable,
)
if len(self.model_list) == 0: if len(self.model_list) == 0:
self.input_model_is_empty = True self.input_model_is_empty = True
else: else:
...@@ -28,35 +49,66 @@ class UNIPipe(AbsPipe): ...@@ -28,35 +49,66 @@ class UNIPipe(AbsPipe):
def pipe_analyze(self): def pipe_analyze(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.model_list = doc_analyze(self.pdf_bytes, ocr=False, self.model_list = doc_analyze(
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.dataset,
lang=self.lang, layout_model=self.layout_model, ocr=False,
formula_enable=self.formula_enable, table_enable=self.table_enable) start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.model_list = doc_analyze(self.pdf_bytes, ocr=True, self.model_list = doc_analyze(
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.dataset,
lang=self.lang, layout_model=self.layout_model, ocr=True,
formula_enable=self.formula_enable, table_enable=self.table_enable) start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
def pipe_parse(self): def pipe_parse(self):
if self.pdf_type == self.PIP_TXT: if self.pdf_type == self.PIP_TXT:
self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_union_pdf(
is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty, self.dataset,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.model_list,
lang=self.lang, layout_model=self.layout_model, self.image_writer,
formula_enable=self.formula_enable, table_enable=self.table_enable) is_debug=self.is_debug,
start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
layout_model=self.layout_model,
formula_enable=self.formula_enable,
table_enable=self.table_enable,
)
elif self.pdf_type == self.PIP_OCR: elif self.pdf_type == self.PIP_OCR:
self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, self.pdf_mid_data = parse_ocr_pdf(
is_debug=self.is_debug, self.dataset,
start_page_id=self.start_page_id, end_page_id=self.end_page_id, self.model_list,
lang=self.lang) self.image_writer,
is_debug=self.is_debug,
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON): start_page_id=self.start_page_id,
end_page_id=self.end_page_id,
lang=self.lang,
)
def pipe_mk_uni_format(
self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON
):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode) result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info('uni_pipe mk content list finished') logger.info('uni_pipe mk content list finished')
return result return result
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD): def pipe_mk_markdown(
self,
img_parent_path: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode) result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f'uni_pipe mk {md_make_mode} finished') logger.info(f'uni_pipe mk {md_make_mode} finished')
return result return result
...@@ -65,6 +117,7 @@ class UNIPipe(AbsPipe): ...@@ -65,6 +117,7 @@ class UNIPipe(AbsPipe):
if __name__ == '__main__': if __name__ == '__main__':
# 测试 # 测试
from magic_pdf.data.data_reader_writer import DataReader from magic_pdf.data.data_reader_writer import DataReader
drw = DataReader(r'D:/project/20231108code-clean') drw = DataReader(r'D:/project/20231108code-clean')
pdf_file_path = r'linshixuqiu\19983-00.pdf' pdf_file_path = r'linshixuqiu\19983-00.pdf'
...@@ -82,10 +135,7 @@ if __name__ == '__main__': ...@@ -82,10 +135,7 @@ if __name__ == '__main__':
# "model_list": model_list # "model_list": model_list
# } # }
jso_useful_key = { jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
'_pdf_type': '',
'model_list': model_list
}
pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer) pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
pipe.pipe_classify() pipe.pipe_classify()
pipe.pipe_parse() pipe.pipe_parse()
...@@ -94,5 +144,7 @@ if __name__ == '__main__': ...@@ -94,5 +144,7 @@ if __name__ == '__main__':
md_writer = DataWriter(write_path) md_writer = DataWriter(write_path)
md_writer.write_string('19983-00.md', md_content) md_writer.write_string('19983-00.md', md_content)
md_writer.write_string('19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)) md_writer.write_string(
'19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
)
md_writer.write_string('19983-00.txt', str(content_list)) md_writer.write_string('19983-00.txt', str(content_list))
import json
import os
from typing import Callable
import copy
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
draw_span_bbox)
from magic_pdf.libs.json_compressor import JsonCompressor
class PipeResult:
def __init__(self, pipe_res, dataset: Dataset):
"""Initialized.
Args:
pipe_res (list[dict]): the pipeline processed result of model inference result
dataset (Dataset): the dataset associated with pipe_res
"""
self._pipe_res = pipe_res
self._dataset = dataset
def dump_md(
self,
writer: DataWriter,
file_path: str,
img_dir_or_bucket_prefix: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
"""Dump The Markdown.
Args:
writer (DataWriter): File writer handle
file_path (str): The file location of markdown
img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.WHOLE_PDF.
md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
"""
pdf_info_list = self._pipe_res['pdf_info']
md_content = union_make(
pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
)
writer.write_string(file_path, md_content)
def dump_content_list(
self, writer: DataWriter, file_path: str, image_dir_or_bucket_prefix: str
):
"""Dump Content List.
Args:
writer (DataWriter): File writer handle
file_path (str): The file location of content list
image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
"""
pdf_info_list = self._pipe_res['pdf_info']
content_list = union_make(
pdf_info_list,
MakeMode.STANDARD_FORMAT,
DropMode.NONE,
image_dir_or_bucket_prefix,
)
writer.write_string(
file_path, json.dumps(content_list, ensure_ascii=False, indent=4)
)
def dump_middle_json(self, writer: DataWriter, file_path: str):
"""Dump the result of pipeline.
Args:
writer (DataWriter): File writer handler
file_path (str): The file location of middle json
"""
writer.write_string(
file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
)
def draw_layout(self, file_path: str) -> None:
"""Draw the layout.
Args:
file_path (str): The file location of layout result file
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
pdf_info = self._pipe_res['pdf_info']
draw_layout_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
def draw_span(self, file_path: str):
"""Draw the Span.
Args:
file_path (str): The file location of span result file
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
pdf_info = self._pipe_res['pdf_info']
draw_span_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
def draw_line_sort(self, file_path: str):
"""Draw line sort.
Args:
file_path (str): The file location of line sort result file
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
pdf_info = self._pipe_res['pdf_info']
draw_line_sort_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
def get_compress_pdf_mid_data(self):
"""Compress the pipeline result.
Returns:
str: compress the pipeline result and return
"""
return JsonCompressor.compress_json(self.pdf_mid_data)
def apply(self, proc: Callable, *args, **kwargs):
"""Apply callable method which.
Args:
proc (Callable): invoke proc as follows:
proc(pipeline_result, *args, **kwargs)
Returns:
Any: return the result generated by proc
"""
return proc(copy.deepcopy(self._pipe_res), *args, **kwargs)
import copy
import json as json_parse
import os import os
import click import click
...@@ -7,13 +5,12 @@ import fitz ...@@ -7,13 +5,12 @@ import fitz
from loguru import logger from loguru import logger
import magic_pdf.model as model_config import magic_pdf.model as model_config
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.make_content_config import DropMode, MakeMode from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox, from magic_pdf.data.dataset import PymuDocDataset
draw_model_bbox, draw_span_bbox) from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.OCRPipe import OCRPipe from magic_pdf.model.operators import InferenceResult
from magic_pdf.pipe.TXTPipe import TXTPipe
from magic_pdf.pipe.UNIPipe import UNIPipe
# from io import BytesIO # from io import BytesIO
# from pypdf import PdfReader, PdfWriter # from pypdf import PdfReader, PdfWriter
...@@ -56,7 +53,11 @@ def prepare_env(output_dir, pdf_file_name, method): ...@@ -56,7 +53,11 @@ def prepare_env(output_dir, pdf_file_name, method):
def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None): def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
document = fitz.open('pdf', pdf_bytes) document = fitz.open('pdf', pdf_bytes)
output_document = fitz.open() output_document = fitz.open()
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1 end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(document) - 1
)
if end_page_id > len(document) - 1: if end_page_id > len(document) - 1:
logger.warning('end_page_id is out of range, use pdf_docs length') logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(document) - 1 end_page_id = len(document) - 1
...@@ -94,78 +95,123 @@ def do_parse( ...@@ -94,78 +95,123 @@ def do_parse(
f_draw_model_bbox = True f_draw_model_bbox = True
f_draw_line_sort_bbox = True f_draw_line_sort_bbox = True
if lang == "": if lang == '':
lang = None lang = None
pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id, end_page_id) pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
pdf_bytes, start_page_id, end_page_id
)
orig_model_list = copy.deepcopy(model_list) local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
parse_method)
image_writer, md_writer = FileBasedDataWriter( image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
local_image_dir), FileBasedDataWriter(local_md_dir) local_md_dir
)
image_dir = str(os.path.basename(local_image_dir)) image_dir = str(os.path.basename(local_image_dir))
if parse_method == 'auto': ds = PymuDocDataset(pdf_bytes)
jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
# start_page_id=start_page_id, end_page_id=end_page_id,
lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'txt':
pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
# start_page_id=start_page_id, end_page_id=end_page_id,
lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
elif parse_method == 'ocr':
pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
# start_page_id=start_page_id, end_page_id=end_page_id,
lang=lang,
layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
else:
logger.error('unknown parse method')
exit(1)
pipe.pipe_classify()
if len(model_list) == 0: if len(model_list) == 0:
if model_config.__use_inside_model__: if model_config.__use_inside_model__:
pipe.pipe_analyze() if parse_method == 'auto':
orig_model_list = copy.deepcopy(pipe.model_list) if ds.classify() == SupportedPdfParseMethod.TXT:
infer_result = ds.apply(
doc_analyze,
ocr=False,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
else:
infer_result = ds.apply(
doc_analyze,
ocr=True,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_auto_mode(
image_writer, debug_mode=True, lang=lang
)
elif parse_method == 'txt':
infer_result = ds.apply(
doc_analyze,
ocr=False,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
)
elif parse_method == 'ocr':
infer_result = ds.apply(
doc_analyze,
ocr=True,
lang=lang,
layout_model=layout_model,
formula_enable=formula_enable,
table_enable=table_enable,
)
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
)
else:
logger.error('unknown parse method')
exit(1)
else: else:
logger.error('need model list input') logger.error('need model list input')
exit(2) exit(2)
else:
infer_result = InferenceResult(model_list, ds)
if parse_method == 'ocr':
pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=lang
)
elif parse_method == 'txt':
pipe_result = infer_result.pipe_txt_mode(
image_writer, debug_mode=True, lang=lang
)
else:
pipe_result = infer_result.pipe_auto_mode(
image_writer, debug_mode=True, lang=lang
)
if f_draw_model_bbox:
infer_result.draw_model(
os.path.join(local_md_dir, f'{pdf_file_name}_model.pdf')
)
pipe.pipe_parse()
pdf_info = pipe.pdf_mid_data['pdf_info']
if f_draw_layout_bbox: if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) pipe_result.draw_layout(
os.path.join(local_md_dir, f'{pdf_file_name}_layout.pdf')
)
if f_draw_span_bbox: if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) pipe_result.draw_span(os.path.join(local_md_dir, f'{pdf_file_name}_spans.pdf'))
if f_draw_model_bbox:
draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
if f_draw_line_sort_bbox: if f_draw_line_sort_bbox:
draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name) pipe_result.draw_line_sort(
os.path.join(local_md_dir, f'{pdf_file_name}_line_sort.pdf')
)
md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
if f_dump_md: if f_dump_md:
md_writer.write_string( pipe_result.dump_md(
md_writer,
f'{pdf_file_name}.md', f'{pdf_file_name}.md',
md_content image_dir,
drop_mode=DropMode.NONE,
md_make_mode=f_make_md_mode,
) )
if f_dump_middle_json: if f_dump_middle_json:
md_writer.write_string( pipe_result.dump_middle_json(md_writer, f'{pdf_file_name}_middle.json')
f'{pdf_file_name}_middle.json',
json_parse.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
)
if f_dump_model_json: if f_dump_model_json:
md_writer.write_string( infer_result.dump_model(md_writer, f'{pdf_file_name}_model.json')
f'{pdf_file_name}_model.json',
json_parse.dumps(orig_model_list, ensure_ascii=False, indent=4)
)
if f_dump_orig_pdf: if f_dump_orig_pdf:
md_writer.write( md_writer.write(
...@@ -173,11 +219,11 @@ def do_parse( ...@@ -173,11 +219,11 @@ def do_parse(
pdf_bytes, pdf_bytes,
) )
content_list = pipe.pipe_mk_uni_format(image_dir, drop_mode=DropMode.NONE)
if f_dump_content_list: if f_dump_content_list:
md_writer.write_string( pipe_result.dump_content_list(
md_writer,
f'{pdf_file_name}_content_list.json', f'{pdf_file_name}_content_list.json',
json_parse.dumps(content_list, ensure_ascii=False, indent=4) image_dir
) )
logger.info(f'local output dir is {local_md_dir}') logger.info(f'local output dir is {local_md_dir}')
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
from loguru import logger from loguru import logger
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.version import __version__ from magic_pdf.libs.version import __version__
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
...@@ -19,13 +20,21 @@ PARSE_TYPE_TXT = 'txt' ...@@ -19,13 +20,21 @@ PARSE_TYPE_TXT = 'txt'
PARSE_TYPE_OCR = 'ocr' PARSE_TYPE_OCR = 'ocr'
def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False, def parse_txt_pdf(
start_page_id=0, end_page_id=None, lang=None, dataset: Dataset,
*args, **kwargs): model_list: list,
imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""解析文本类pdf.""" """解析文本类pdf."""
pdf_info_dict = parse_pdf_by_txt( pdf_info_dict = parse_pdf_by_txt(
pdf_bytes, dataset,
pdf_models, model_list,
imageWriter, imageWriter,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -43,13 +52,21 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i ...@@ -43,13 +52,21 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
return pdf_info_dict return pdf_info_dict
def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False, def parse_ocr_pdf(
start_page_id=0, end_page_id=None, lang=None, dataset: Dataset,
*args, **kwargs): model_list: list,
imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""解析ocr类pdf.""" """解析ocr类pdf."""
pdf_info_dict = parse_pdf_by_ocr( pdf_info_dict = parse_pdf_by_ocr(
pdf_bytes, dataset,
pdf_models, model_list,
imageWriter, imageWriter,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -67,17 +84,24 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i ...@@ -67,17 +84,24 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
return pdf_info_dict return pdf_info_dict
def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False, def parse_union_pdf(
input_model_is_empty: bool = False, dataset: Dataset,
start_page_id=0, end_page_id=None, lang=None, model_list: list,
*args, **kwargs): imageWriter: DataWriter,
is_debug=False,
start_page_id=0,
end_page_id=None,
lang=None,
*args,
**kwargs
):
"""ocr和文本混合的pdf,全部解析出来.""" """ocr和文本混合的pdf,全部解析出来."""
def parse_pdf(method): def parse_pdf(method):
try: try:
return method( return method(
pdf_bytes, dataset,
pdf_models, model_list,
imageWriter, imageWriter,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -91,12 +115,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, ...@@ -91,12 +115,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
pdf_info_dict = parse_pdf(parse_pdf_by_txt) pdf_info_dict = parse_pdf(parse_pdf_by_txt)
if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False): if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr') logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
if input_model_is_empty: if len(model_list) == 0:
layout_model = kwargs.get('layout_model', None) layout_model = kwargs.get('layout_model', None)
formula_enable = kwargs.get('formula_enable', None) formula_enable = kwargs.get('formula_enable', None)
table_enable = kwargs.get('table_enable', None) table_enable = kwargs.get('table_enable', None)
pdf_models = doc_analyze( infer_res = doc_analyze(
pdf_bytes, dataset,
ocr=True, ocr=True,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id, end_page_id=end_page_id,
...@@ -105,6 +129,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, ...@@ -105,6 +129,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
formula_enable=formula_enable, formula_enable=formula_enable,
table_enable=table_enable, table_enable=table_enable,
) )
model_list = infer_res.get_infer_res()
pdf_info_dict = parse_pdf(parse_pdf_by_ocr) pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
if pdf_info_dict is None: if pdf_info_dict is None:
raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.') raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
......
This diff is collapsed.
...@@ -7,3 +7,5 @@ ...@@ -7,3 +7,5 @@
api/read_api api/read_api
api/schemas api/schemas
api/io api/io
api/pipe_operators
api/model_operators
\ No newline at end of file
Model Api
==========
.. autoclass:: magic_pdf.model.InferenceResultBase
:members:
:inherited-members:
:show-inheritance:
Pipeline Api
=============
.. autoclass:: magic_pdf.pipe.operators.PipeResult
:members:
:inherited-members:
:show-inheritance:
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment