Commit 4a82d6a0 authored by icecraft's avatar icecraft Committed by xu rui
Browse files

feat: add function definitions

parent a3a720ea
......@@ -32,10 +32,28 @@ class PageableData(ABC):
@abstractmethod
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
    """Draw a rectangle.

    Interface-only method: it does no work here and returns None.

    Args:
        rect_coords (list[float]): four-element array holding the top-left
            and bottom-right corners, [x0, y0, x1, y1]
        color (list[float] | None): three-element RGB value of the border
            line; None means no border line
        fill (list[float] | None): RGB value used to fill the rectangle;
            None means the rectangle is not filled
        fill_opacity (float): opacity of the fill, in the range [0, 1]
        width (float): width of the border line
        overlay (bool): selects whether the rectangle is drawn in the
            foreground or in the background.
            NOTE(review): the earlier wording said True means background,
            while PyMuPDF documents overlay=True as foreground -- confirm
            which convention implementations follow.
    """
@abstractmethod
def insert_text(self, coord, content, fontsize, color):
    """Insert text.

    Interface-only method: it does no work here and returns None.

    Args:
        coord (list[float]): four-element array holding the top-left and
            bottom-right corners, [x0, y0, x1, y1]
            NOTE(review): PyMuPDF's insert_text takes a single point --
            confirm the shape implementations actually expect.
        content (str): the text content
        fontsize (int): font size of the text
        color (list[float] | None): three-element RGB value; None uses the
            default font color (the earlier wording mentioned a border line,
            apparently copied from draw_rect)
    """
......@@ -244,6 +262,16 @@ class Doc(PageableData):
return getattr(self._doc, name)
def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
"""draw rectangle.
Args:
rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
color (list[float] | None): three element tuple which descript the RGB of the board line, None means no board line
fill (list[float] | None): fill the board with RGB, None means will not fill with color
fill_opacity (float): opacity of the fill, range from [0, 1]
width (float): the width of board
overlay (bool): fill the color in foreground or background. True means fill in background.
"""
self._doc.draw_rect(
rect_coords,
color=color,
......@@ -254,4 +282,12 @@ class Doc(PageableData):
)
def insert_text(self, coord, content, fontsize, color):
    """Insert text by delegating to the wrapped ``self._doc`` object.

    Args:
        coord (list[float]): four-element array holding the top-left and
            bottom-right corners, [x0, y0, x1, y1]
            NOTE(review): PyMuPDF's insert_text takes a single point --
            confirm the shape the wrapped object expects.
        content (str): the text content
        fontsize (int): font size of the text
        color (list[float] | None): three-element RGB value; None uses the
            default font color
    """
    # Position and content go positionally, styling goes by keyword.
    text_style = {'fontsize': fontsize, 'color': color}
    self._doc.insert_text(coord, content, **text_style)
......@@ -13,7 +13,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL
from magic_pdf.model.types import InferenceResult
from magic_pdf.model.operators import InferenceResult
def dict_compare(d1, d2):
......
......@@ -9,15 +9,26 @@ from magic_pdf.data.dataset import Dataset
from magic_pdf.filter import classify
from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.types import PipeResult
from magic_pdf.pipe.operators import PipeResult
class InferenceResult:
def __init__(self, inference_results: list, dataset: Dataset):
    """Store the model output together with its dataset.

    Args:
        inference_results (list): the inference result generated by the model
        dataset (Dataset): the dataset associated with the inference result
    """
    self._infer_res, self._dataset = inference_results, dataset
def draw_model(self, file_path: str) -> None:
"""Draw model inference result.
Args:
file_path (str): the output file path
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
......@@ -27,14 +38,34 @@ class InferenceResult:
)
def dump_model(self, writer: DataWriter, file_path: str):
    """Serialize the model inference result as JSON and write it out.

    Args:
        writer (DataWriter): writer handle; its ``write_string`` receives the
            target path and the JSON text
        file_path (str): the location of the target file
    """
    # Keep non-ASCII characters readable and pretty-print with 4 spaces.
    serialized = json.dumps(self._infer_res, ensure_ascii=False, indent=4)
    writer.write_string(file_path, serialized)
def get_infer_res(self):
    """Return the stored model inference result.

    The stored object itself is returned, not a copy.

    Returns:
        list[dict]: the inference result generated by the model
    """
    infer_res = self._infer_res
    return infer_res
def apply(self, proc: Callable, *args, **kwargs):
    """Run a callable on a deep copy of the inference result.

    Args:
        proc (Callable): invoked as follows:
            proc(inference_result, *args, **kwargs)

    Returns:
        Any: whatever proc returns
    """
    # proc gets its own copy, so the stored result is never mutated by it.
    snapshot = copy.deepcopy(self._infer_res)
    return proc(snapshot, *args, **kwargs)
def pipe_auto_mode(
......@@ -45,33 +76,30 @@ class InferenceResult:
debug_mode=False,
lang=None,
) -> PipeResult:
def proc(*args, **kwargs) -> PipeResult:
res = pdf_parse_union(*args, **kwargs)
return PipeResult(res, self._dataset)
"""Post-proc the model inference result.
step1: classify the dataset type
step2: based on the result of step1, use `pipe_txt_mode` or `pipe_ocr_mode`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (_type_, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (_type_, optional): Defaults to None.
Returns:
PipeResult: the result
"""
pdf_proc_method = classify(self._dataset.data_bits())
if pdf_proc_method == SupportedPdfParseMethod.TXT:
return self.apply(
proc,
self._dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
return self.pipe_txt_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
else:
return self.apply(
proc,
self._dataset,
imageWriter,
SupportedPdfParseMethod.OCR,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
return self.pipe_ocr_mode(
imageWriter, start_page_id, end_page_id, debug_mode, lang
)
def pipe_txt_mode(
......@@ -82,6 +110,20 @@ class InferenceResult:
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using the
third library, such as `pymupdf`
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (_type_, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (_type_, optional): Defaults to None.
Returns:
PipeResult: the result
"""
def proc(*args, **kwargs) -> PipeResult:
res = pdf_parse_union(*args, **kwargs)
return PipeResult(res, self._dataset)
......@@ -91,10 +133,10 @@ class InferenceResult:
self._dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
def pipe_ocr_mode(
......@@ -105,6 +147,19 @@ class InferenceResult:
debug_mode=False,
lang=None,
) -> PipeResult:
"""Post-proc the model inference result, Extract the text using `OCR`
technical.
Args:
imageWriter (DataWriter): the image writer handle
start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
end_page_id (_type_, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
debug_mode (bool, optional): Defaults to False. will dump more log if enabled
lang (_type_, optional): Defaults to None.
Returns:
PipeResult: the result
"""
def proc(*args, **kwargs) -> PipeResult:
res = pdf_parse_union(*args, **kwargs)
......@@ -115,8 +170,8 @@ class InferenceResult:
self._dataset,
imageWriter,
SupportedPdfParseMethod.TXT,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
......@@ -4,8 +4,8 @@ import statistics
import time
from typing import List
import torch
import fitz
import torch
from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
......@@ -16,17 +16,13 @@ from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.model.magic_model import MagicModel
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import torchtext
if torchtext.__version__ >= "0.18.0":
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
......@@ -39,6 +35,9 @@ from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layo
from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, remove_overlaps_min_spans
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
def __replace_STX_ETX(text_str: str):
"""Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
......@@ -90,7 +89,10 @@ def chars_to_content(span):
# Punctuation tuples; presumably characters that may end / begin a text line
# (their use sites are outside this view -- verify against callers).
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
# Committed merge-conflict markers used to surround this constant and made the
# module a syntax error. The HEAD side is kept; the other side was empty.
LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
def fill_char_in_spans(spans, all_chars):
......@@ -233,7 +235,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
# 初始化ocr模型
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name="ocr",
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
......@@ -241,7 +243,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
for span in empty_spans:
# 对span的bbox截图再ocr
span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode="cv2")
span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
ocr_res = ocr_model.ocr(span_img, det=False)
if ocr_res and len(ocr_res) > 0:
if len(ocr_res[0]) > 0:
......@@ -681,7 +683,7 @@ def parse_page_core(
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if parse_mode == SupportedPdfParseMethod.TXT:
"""使用新版本的混合ocr方案"""
"""使用新版本的混合ocr方案."""
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
elif parse_mode == SupportedPdfParseMethod.OCR:
......@@ -689,7 +691,6 @@ def parse_page_core(
else:
raise Exception('parse_mode must be txt or ocr')
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
......@@ -762,8 +763,8 @@ def parse_page_core(
def pdf_parse_union(
dataset: Dataset,
model_list,
dataset: Dataset,
imageWriter,
parse_mode,
start_page_id=0,
......@@ -832,4 +833,4 @@ def pdf_parse_union(
if __name__ == '__main__':
pass
\ No newline at end of file
pass
import json
import os
......@@ -13,23 +12,76 @@ from magic_pdf.libs.json_compressor import JsonCompressor
class PipeResult:
def __init__(self, pipe_res, dataset: Dataset):
    """Store the pipeline result together with its dataset.

    Args:
        pipe_res (list[dict]): the pipeline-processed result of the model
            inference result
        dataset (Dataset): the dataset associated with pipe_res
    """
    self._pipe_res, self._dataset = pipe_res, dataset
def dump_md(self, writer: DataWriter, file_path: str, img_dir_or_bucket_prefix: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
def dump_md(
self,
writer: DataWriter,
file_path: str,
img_dir_or_bucket_prefix: str,
drop_mode=DropMode.WHOLE_PDF,
md_make_mode=MakeMode.MM_MD,
):
"""Dump The Markdown.
Args:
writer (DataWriter): File writer handle
file_path (str): The file location of markdown
img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
drop_mode (str, optional): Drop strategy when some page which is corrupted or inappropriate. Defaults to DropMode.WHOLE_PDF.
md_make_mode (str, optional): The content Type of Markdown be made. Defaults to MakeMode.MM_MD.
"""
pdf_info_list = self._pipe_res['pdf_info']
md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix)
md_content = union_make(
pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
)
writer.write_string(file_path, md_content)
def dump_content_list(self, writer: DataWriter, file_path: str, image_dir_or_bucket_prefix: str, drop_mode=DropMode.NONE):
def dump_content_list(
self, writer: DataWriter, file_path: str, image_dir_or_bucket_prefix: str
):
"""Dump Content List.
Args:
writer (DataWriter): File writer handle
file_path (str): The file location of content list
image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory which used to store the figure
"""
pdf_info_list = self._pipe_res['pdf_info']
content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, image_dir_or_bucket_prefix)
writer.write_string(file_path, json.dumps(content_list, ensure_ascii=False, indent=4))
content_list = union_make(
pdf_info_list,
MakeMode.STANDARD_FORMAT,
DropMode.NONE,
image_dir_or_bucket_prefix,
)
writer.write_string(
file_path, json.dumps(content_list, ensure_ascii=False, indent=4)
)
def dump_middle_json(self, writer: DataWriter, file_path: str):
writer.write_string(file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4))
"""Dump the result of pipeline.
Args:
writer (DataWriter): File writer handler
file_path (str): The file location of middle json
"""
writer.write_string(
file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
)
def draw_layout(self, file_path: str) -> None:
"""Draw the layout.
Args:
file_path (str): The file location of layout result file
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
......@@ -38,6 +90,11 @@ class PipeResult:
draw_layout_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
def draw_span(self, file_path: str):
"""Draw the Span.
Args:
file_path (str): The file location of span result file
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
......@@ -46,6 +103,11 @@ class PipeResult:
draw_span_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
def draw_line_sort(self, file_path: str):
"""Draw line sort.
Args:
file_path (str): The file location of line sort result file
"""
dir_name = os.path.dirname(file_path)
base_name = os.path.basename(file_path)
if not os.path.exists(dir_name):
......@@ -53,10 +115,10 @@ class PipeResult:
pdf_info = self._pipe_res['pdf_info']
draw_line_sort_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
def draw_content_list(self, writer: DataWriter, file_path: str, img_dir_or_bucket_prefix: str, drop_mode=DropMode.WHOLE_PDF):
pdf_info_list = self._pipe_res['pdf_info']
content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_dir_or_bucket_prefix)
writer.write_string(file_path, json.dumps(content_list, ensure_ascii=False, indent=4))
def get_compress_pdf_mid_data(self):
    """Compress the pipeline result.

    Returns:
        str: the compressed pipeline result, as returned by
            ``JsonCompressor.compress_json`` (return type taken from the
            original docstring)
    """
    # The constructor stores the pipeline result as ``_pipe_res``. The
    # previously referenced ``pdf_mid_data`` is not set anywhere in the
    # visible class and would raise AttributeError.
    return JsonCompressor.compress_json(self._pipe_res)
......@@ -10,7 +10,7 @@ from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.model.types import InferenceResult
from magic_pdf.model.operators import InferenceResult
# from io import BytesIO
# from pypdf import PdfReader, PdfWriter
......@@ -223,8 +223,7 @@ def do_parse(
pipe_result.dump_content_list(
md_writer,
f'{pdf_file_name}_content_list.json',
image_dir,
drop_mode=DropMode.NONE,
image_dir
)
logger.info(f'local output dir is {local_md_dir}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment