Unverified commit d94ddcf8, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #1318 from icecraft/refactor/code_struture

refactor: refactor code
parents 303a4b01 b2887ca0
from typing import Callable
from abc import ABC, abstractmethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.pipe.operators import PipeResult
__use_inside_model__ = True
__model_mode__ = "full"

class InferenceResultBase(ABC):

    @abstractmethod
    def __init__(self, inference_results: list, dataset: Dataset):
        """Initialization method.

        Args:
            inference_results (list): the inference result generated by the model
            dataset (Dataset): the dataset related to the model inference result
        """
        self._infer_res = inference_results
        self._dataset = dataset

    @abstractmethod
    def draw_model(self, file_path: str) -> None:
        """Draw the model inference result.

        Args:
            file_path (str): the output file path
        """
        pass

    @abstractmethod
    def dump_model(self, writer: DataWriter, file_path: str):
        """Dump the model inference result to a file.

        Args:
            writer (DataWriter): writer handle
            file_path (str): the location of the target file
        """
        pass

    @abstractmethod
    def get_infer_res(self):
        """Get the inference result.

        Returns:
            list: the inference result generated by the model
        """
        pass

    @abstractmethod
    def apply(self, proc: Callable, *args, **kwargs):
        """Apply a callable to the inference result.

        Args:
            proc (Callable): invoked as proc(inference_result, *args, **kwargs)

        Returns:
            Any: the result generated by proc
        """
        pass

    @abstractmethod
    def pipe_txt_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Post-process the model inference result and extract the text with a
        third-party library such as `pymupdf`.

        Args:
            imageWriter (DataWriter): the image writer handle
            start_page_id (int, optional): Defaults to 0. Lets the user choose the first page to process
            end_page_id (int, optional): Defaults to the last page index of the dataset. Lets the user choose the last page to process
            debug_mode (bool, optional): Defaults to False. Dump more logs if enabled
            lang (str, optional): Defaults to None.

        Returns:
            PipeResult: the result
        """
        pass

    @abstractmethod
    def pipe_ocr_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Post-process the model inference result and extract the text with OCR.

        Args:
            imageWriter (DataWriter): the image writer handle
            start_page_id (int, optional): Defaults to 0. Lets the user choose the first page to process
            end_page_id (int, optional): Defaults to the last page index of the dataset. Lets the user choose the last page to process
            debug_mode (bool, optional): Defaults to False. Dump more logs if enabled
            lang (str, optional): Defaults to None.

        Returns:
            PipeResult: the result
        """
        pass
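
A minimal usage sketch of the base API above, assuming the helpers that appear elsewhere in this PR (PymuDocDataset, FileBasedDataWriter, doc_analyze) keep their usual call shapes; doc_analyze stands in for any factory that returns a concrete InferenceResultBase implementation, and the file paths are illustrative.

from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

# Build a dataset from raw PDF bytes (path is illustrative).
with open('some_doc.pdf', 'rb') as f:
    ds = PymuDocDataset(f.read())

# doc_analyze yields a concrete implementation of InferenceResultBase.
infer_result = doc_analyze(ds)

image_writer = FileBasedDataWriter('output/images')
model_writer = FileBasedDataWriter('output')

# Visualize and persist the raw model inference result.
infer_result.draw_model('output/model.pdf')
infer_result.dump_model(model_writer, 'model.json')

# Post-process into a PipeResult via the text pipeline.
pipe_result = infer_result.pipe_txt_mode(image_writer)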
@@ -11,17 +11,12 @@ from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.operators import InferenceResult
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
clean_vram,
crop_img,
get_res_list_from_layout_res,
)
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res,
get_ocr_result_list,
)
get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1
@@ -50,7 +45,7 @@ class BatchAnalyze:
pil_img = Image.fromarray(image)
width, height = pil_img.size
if height > width:
input_res = {"poly": [0, 0, width, 0, width, height, 0, height]}
input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
)
@@ -65,17 +60,17 @@ class BatchAnalyze:
for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]:
for i in range(len(res["poly"])):
for i in range(len(res['poly'])):
if i % 2 == 0:
res["poly"][i] = (
res["poly"][i] - useful_list[0] + useful_list[2]
res['poly'][i] = (
res['poly'][i] - useful_list[0] + useful_list[2]
)
else:
res["poly"][i] = (
res["poly"][i] - useful_list[1] + useful_list[3]
res['poly'][i] = (
res['poly'][i] - useful_list[1] + useful_list[3]
)
logger.info(
f"layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}"
f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
)
if self.model.apply_formula:
@@ -85,7 +80,7 @@ class BatchAnalyze:
images, self.batch_ratio * MFD_BASE_BATCH_SIZE
)
logger.info(
f"mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}"
f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
)
# Formula recognition
@@ -98,7 +93,7 @@ class BatchAnalyze:
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
logger.info(
f"mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}"
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
)
# Free GPU memory
@@ -156,7 +151,7 @@ class BatchAnalyze:
if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
table_result = self.model.table_model.predict(
new_image, "html"
new_image, 'html'
)
if len(table_result) > 0:
html_code = table_result[0]
@@ -169,32 +164,32 @@ class BatchAnalyze:
run_time = time.time() - single_table_start_time
if run_time > self.model.table_max_time:
logger.warning(
f"table recognition processing exceeds max time {self.model.table_max_time}s"
f'table recognition processing exceeds max time {self.model.table_max_time}s'
)
# Check whether the result was returned normally
if html_code:
expected_ending = html_code.strip().endswith(
"</html>"
) or html_code.strip().endswith("</table>")
'</html>'
) or html_code.strip().endswith('</table>')
if expected_ending:
res["html"] = html_code
res['html'] = html_code
else:
logger.warning(
"table recognition processing fails, not found expected HTML table end"
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(
"table recognition processing fails, not get html return"
'table recognition processing fails, not get html return'
)
table_time += time.time() - table_start
table_count += len(table_res_list)
if self.model.apply_ocr:
logger.info(f"ocr time: {round(ocr_time, 2)}, image num: {ocr_count}")
logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
else:
logger.info(f"det time: {round(ocr_time, 2)}, image num: {ocr_count}")
logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
if self.model.apply_table:
logger.info(f"table time: {round(table_time, 2)}, image num: {table_count}")
logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
return images_layout_res
@@ -211,8 +206,7 @@ def doc_batch_analyze(
table_enable=None,
batch_ratio: int | None = None,
) -> InferenceResult:
"""
Perform batch analysis on a document dataset.
"""Perform batch analysis on a document dataset.
Args:
dataset (Dataset): The dataset containing document pages to be analyzed.
@@ -234,9 +228,9 @@ def doc_batch_analyze(
"""
if not torch.cuda.is_available():
raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")
raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
lang = None if lang == "" else lang
lang = None if lang == '' else lang
# TODO: auto detect batch size
batch_ratio = 1 if batch_ratio is None else batch_ratio
end_page_id = end_page_id if end_page_id else len(dataset)
@@ -255,26 +249,26 @@ def doc_batch_analyze(
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict["img"])
images.append(img_dict['img'])
analyze_result = batch_model(images)
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
page_width = img_dict["width"]
page_height = img_dict["height"]
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
else:
result = []
page_info = {"page_no": index, "height": page_height, "width": page_width}
page_dict = {"layout_dets": result, "page_info": page_info}
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time()
clean_memory()
logger.info(f"clean memory time: {round(time.time() - clean_memory_start_time, 2)}")
logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
return InferenceResult(model_json, dataset)
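
For reviewers, a hedged sketch of calling the batch entry point defined above; it relies only on the signature fragments visible in this hunk (dataset first, the batch_ratio keyword, and the CUDA requirement), leaves every other keyword at its default, and uses illustrative paths.

from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset

with open('some_doc.pdf', 'rb') as f:
    ds = PymuDocDataset(f.read())

# Raises CUDA_NOT_AVAILABLE when no GPU is present.
inference_result = doc_batch_analyze(ds, batch_ratio=2)

# The per-page model JSON can then be dumped like any other InferenceResult.
inference_result.dump_model(FileBasedDataWriter('output'), 'model.json')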
import os
import time
from loguru import logger
# Disable paddle's signal handler
import paddle
from loguru import logger
paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # Disable the albumentations update check
@@ -25,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir,
get_table_recog_config)
from magic_pdf.model.model_list import MODEL
from magic_pdf.model.operators import InferenceResult
from magic_pdf.operators.models import InferenceResult
def dict_compare(d1, d2):
......
@@ -7,15 +7,13 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.filter import classify
from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.libs.version import __version__
from magic_pdf.model import InferenceResultBase
from magic_pdf.operators.pipes import PipeResult
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.operators import PipeResult
class InferenceResult(InferenceResultBase):
class InferenceResult:
def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method.
......
@@ -10,7 +10,7 @@ from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.model.operators import InferenceResult
from magic_pdf.operators.models import InferenceResult
# from io import BytesIO
# from pypdf import PdfReader, PdfWriter
......
@@ -2,7 +2,7 @@
Model Api
==========
.. autoclass:: magic_pdf.model.InferenceResultBase
.. autoclass:: magic_pdf.operators.models.InferenceResult
:members:
:inherited-members:
:show-inheritance:
@@ -3,7 +3,7 @@
Pipeline Api
=============
.. autoclass:: magic_pdf.pipe.operators.PipeResult
.. autoclass:: magic_pdf.operators.pipes.PipeResult
:members:
:inherited-members:
:show-inheritance:
@@ -122,7 +122,7 @@ Inference Result
.. code:: python
from magic_pdf.model.operators import InferenceResult
from magic_pdf.operators.models import InferenceResult
from magic_pdf.data.dataset import Dataset
dataset : Dataset = some_data_set # not real dataset
@@ -142,4 +142,3 @@ some_model.pdf
.. |Poly Coordinate Diagram| image:: ../_static/image/poly.png
@@ -294,7 +294,7 @@ Pipeline Result
.. code:: python
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.operators import PipeResult
from magic_pdf.operators.pipes import PipeResult
from magic_pdf.data.dataset import Dataset
res = pdf_parse_union(*args, **kwargs)
......
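
Since this PR only relocates the operator classes (magic_pdf.model.operators becomes magic_pdf.operators.models, magic_pdf.pipe.operators becomes magic_pdf.operators.pipes), downstream code should only need an import update. A sketch of the before and after, with a hypothetical call site left untouched:

# Before this PR the imports were:
#   from magic_pdf.model.operators import InferenceResult
#   from magic_pdf.pipe.operators import PipeResult
# After this PR they become:
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.operators.models import InferenceResult
from magic_pdf.operators.pipes import PipeResult


def postprocess(infer_result: InferenceResult, image_writer: DataWriter) -> PipeResult:
    # Call sites are unchanged; only the import paths move.
    return infer_result.pipe_ocr_mode(image_writer)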