Unverified Commit d94ddcf8 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1318 from icecraft/refactor/code_struture

refactor: refactor code
parents 303a4b01 b2887ca0
from typing import Callable
from abc import ABC, abstractmethod
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset
from magic_pdf.pipe.operators import PipeResult
# Module-level feature flags read elsewhere in magic_pdf.
# NOTE(review): semantics inferred from the names only — presumably
# __use_inside_model__ enables the bundled models and __model_mode__
# selects the "full" inference pipeline; confirm against the callers.
__use_inside_model__ = True
__model_mode__ = "full"
class InferenceResultBase(ABC):
    """Abstract container for model inference results.

    Concrete subclasses hold the raw per-page inference output together
    with the :class:`Dataset` it was produced from, and implement
    visualization (``draw_model``), serialization (``dump_model``) and the
    text/OCR post-processing pipelines that turn the raw results into a
    :class:`PipeResult`.
    """

    @abstractmethod
    def __init__(self, inference_results: list, dataset: Dataset):
        """Store the inference results and their source dataset.

        Args:
            inference_results (list): the inference result generated by model
            dataset (Dataset): the dataset related with model inference result
        """
        self._infer_res = inference_results
        self._dataset = dataset

    @abstractmethod
    def draw_model(self, file_path: str) -> None:
        """Draw the model inference result for visual inspection.

        Args:
            file_path (str): the output file path
        """
        pass

    @abstractmethod
    def dump_model(self, writer: DataWriter, file_path: str):
        """Dump the model inference result to a file.

        Args:
            writer (DataWriter): writer handle
            file_path (str): the location of target file
        """
        pass

    @abstractmethod
    def get_infer_res(self):
        """Get the stored inference result.

        Returns:
            list: the inference result generated by model
        """
        pass

    @abstractmethod
    def apply(self, proc: Callable, *args, **kwargs):
        """Apply a callable to the stored inference result.

        Args:
            proc (Callable): invoked as ``proc(inference_result, *args, **kwargs)``

        Returns:
            Any: the result generated by proc
        """
        pass

    @abstractmethod
    def pipe_txt_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Post-process the model inference result, extracting the text with a
        third-party library such as ``pymupdf``.

        Args:
            imageWriter (DataWriter): the image writer handle
            start_page_id (int, optional): Defaults to 0. Lets the user select the first page to process
            end_page_id (int, optional): Defaults to the last page index of dataset. Lets the user select the last page to process
            debug_mode (bool, optional): Defaults to False. Will dump more log if enabled
            lang (str, optional): Defaults to None.

        Returns:
            PipeResult: the result
        """
        pass

    @abstractmethod
    def pipe_ocr_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Post-process the model inference result using the OCR pipeline.

        Same parameters as :meth:`pipe_txt_mode`; text is recovered via OCR
        instead of direct text extraction.

        Args:
            imageWriter (DataWriter): the image writer handle
            start_page_id (int, optional): Defaults to 0. Lets the user select the first page to process
            end_page_id (int, optional): Defaults to the last page index of dataset. Lets the user select the last page to process
            debug_mode (bool, optional): Defaults to False. Will dump more log if enabled
            lang (str, optional): Defaults to None.

        Returns:
            PipeResult: the result
        """
        pass
...@@ -11,17 +11,12 @@ from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE ...@@ -11,17 +11,12 @@ from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.operators import InferenceResult
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, clean_vram, crop_img, get_res_list_from_layout_res)
crop_img,
get_res_list_from_layout_res,
)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_adjusted_mfdetrec_res, get_ocr_result_list)
get_ocr_result_list, from magic_pdf.operators.models import InferenceResult
)
YOLO_LAYOUT_BASE_BATCH_SIZE = 4 YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1 MFD_BASE_BATCH_SIZE = 1
...@@ -50,7 +45,7 @@ class BatchAnalyze: ...@@ -50,7 +45,7 @@ class BatchAnalyze:
pil_img = Image.fromarray(image) pil_img = Image.fromarray(image)
width, height = pil_img.size width, height = pil_img.size
if height > width: if height > width:
input_res = {"poly": [0, 0, width, 0, width, height, 0, height]} input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img( new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0 input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
) )
...@@ -65,17 +60,17 @@ class BatchAnalyze: ...@@ -65,17 +60,17 @@ class BatchAnalyze:
for image_index, useful_list in modified_images: for image_index, useful_list in modified_images:
for res in images_layout_res[image_index]: for res in images_layout_res[image_index]:
for i in range(len(res["poly"])): for i in range(len(res['poly'])):
if i % 2 == 0: if i % 2 == 0:
res["poly"][i] = ( res['poly'][i] = (
res["poly"][i] - useful_list[0] + useful_list[2] res['poly'][i] - useful_list[0] + useful_list[2]
) )
else: else:
res["poly"][i] = ( res['poly'][i] = (
res["poly"][i] - useful_list[1] + useful_list[3] res['poly'][i] - useful_list[1] + useful_list[3]
) )
logger.info( logger.info(
f"layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}" f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
) )
if self.model.apply_formula: if self.model.apply_formula:
...@@ -85,7 +80,7 @@ class BatchAnalyze: ...@@ -85,7 +80,7 @@ class BatchAnalyze:
images, self.batch_ratio * MFD_BASE_BATCH_SIZE images, self.batch_ratio * MFD_BASE_BATCH_SIZE
) )
logger.info( logger.info(
f"mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}" f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
) )
# 公式识别 # 公式识别
...@@ -98,7 +93,7 @@ class BatchAnalyze: ...@@ -98,7 +93,7 @@ class BatchAnalyze:
for image_index in range(len(images)): for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index] images_layout_res[image_index] += images_formula_list[image_index]
logger.info( logger.info(
f"mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}" f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
) )
# 清理显存 # 清理显存
...@@ -156,7 +151,7 @@ class BatchAnalyze: ...@@ -156,7 +151,7 @@ class BatchAnalyze:
if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE: if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad(): with torch.no_grad():
table_result = self.model.table_model.predict( table_result = self.model.table_model.predict(
new_image, "html" new_image, 'html'
) )
if len(table_result) > 0: if len(table_result) > 0:
html_code = table_result[0] html_code = table_result[0]
...@@ -169,32 +164,32 @@ class BatchAnalyze: ...@@ -169,32 +164,32 @@ class BatchAnalyze:
run_time = time.time() - single_table_start_time run_time = time.time() - single_table_start_time
if run_time > self.model.table_max_time: if run_time > self.model.table_max_time:
logger.warning( logger.warning(
f"table recognition processing exceeds max time {self.model.table_max_time}s" f'table recognition processing exceeds max time {self.model.table_max_time}s'
) )
# 判断是否返回正常 # 判断是否返回正常
if html_code: if html_code:
expected_ending = html_code.strip().endswith( expected_ending = html_code.strip().endswith(
"</html>" '</html>'
) or html_code.strip().endswith("</table>") ) or html_code.strip().endswith('</table>')
if expected_ending: if expected_ending:
res["html"] = html_code res['html'] = html_code
else: else:
logger.warning( logger.warning(
"table recognition processing fails, not found expected HTML table end" 'table recognition processing fails, not found expected HTML table end'
) )
else: else:
logger.warning( logger.warning(
"table recognition processing fails, not get html return" 'table recognition processing fails, not get html return'
) )
table_time += time.time() - table_start table_time += time.time() - table_start
table_count += len(table_res_list) table_count += len(table_res_list)
if self.model.apply_ocr: if self.model.apply_ocr:
logger.info(f"ocr time: {round(ocr_time, 2)}, image num: {ocr_count}") logger.info(f'ocr time: {round(ocr_time, 2)}, image num: {ocr_count}')
else: else:
logger.info(f"det time: {round(ocr_time, 2)}, image num: {ocr_count}") logger.info(f'det time: {round(ocr_time, 2)}, image num: {ocr_count}')
if self.model.apply_table: if self.model.apply_table:
logger.info(f"table time: {round(table_time, 2)}, image num: {table_count}") logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
return images_layout_res return images_layout_res
...@@ -211,8 +206,7 @@ def doc_batch_analyze( ...@@ -211,8 +206,7 @@ def doc_batch_analyze(
table_enable=None, table_enable=None,
batch_ratio: int | None = None, batch_ratio: int | None = None,
) -> InferenceResult: ) -> InferenceResult:
""" """Perform batch analysis on a document dataset.
Perform batch analysis on a document dataset.
Args: Args:
dataset (Dataset): The dataset containing document pages to be analyzed. dataset (Dataset): The dataset containing document pages to be analyzed.
...@@ -234,9 +228,9 @@ def doc_batch_analyze( ...@@ -234,9 +228,9 @@ def doc_batch_analyze(
""" """
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode") raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
lang = None if lang == "" else lang lang = None if lang == '' else lang
# TODO: auto detect batch size # TODO: auto detect batch size
batch_ratio = 1 if batch_ratio is None else batch_ratio batch_ratio = 1 if batch_ratio is None else batch_ratio
end_page_id = end_page_id if end_page_id else len(dataset) end_page_id = end_page_id if end_page_id else len(dataset)
...@@ -255,26 +249,26 @@ def doc_batch_analyze( ...@@ -255,26 +249,26 @@ def doc_batch_analyze(
if start_page_id <= index <= end_page_id: if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index) page_data = dataset.get_page(index)
img_dict = page_data.get_image() img_dict = page_data.get_image()
images.append(img_dict["img"]) images.append(img_dict['img'])
analyze_result = batch_model(images) analyze_result = batch_model(images)
for index in range(len(dataset)): for index in range(len(dataset)):
page_data = dataset.get_page(index) page_data = dataset.get_page(index)
img_dict = page_data.get_image() img_dict = page_data.get_image()
page_width = img_dict["width"] page_width = img_dict['width']
page_height = img_dict["height"] page_height = img_dict['height']
if start_page_id <= index <= end_page_id: if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0) result = analyze_result.pop(0)
else: else:
result = [] result = []
page_info = {"page_no": index, "height": page_height, "width": page_width} page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {"layout_dets": result, "page_info": page_info} page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict) model_json.append(page_dict)
# TODO: clean memory when gpu memory is not enough # TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time() clean_memory_start_time = time.time()
clean_memory() clean_memory()
logger.info(f"clean memory time: {round(time.time() - clean_memory_start_time, 2)}") logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
return InferenceResult(model_json, dataset) return InferenceResult(model_json, dataset)
import os import os
import time import time
from loguru import logger
# 关闭paddle的信号处理 # 关闭paddle的信号处理
import paddle import paddle
from loguru import logger
paddle.disable_signal_handler() paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
...@@ -25,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config, ...@@ -25,7 +25,7 @@ from magic_pdf.libs.config_reader import (get_device, get_formula_config,
get_local_models_dir, get_local_models_dir,
get_table_recog_config) get_table_recog_config)
from magic_pdf.model.model_list import MODEL from magic_pdf.model.model_list import MODEL
from magic_pdf.model.operators import InferenceResult from magic_pdf.operators.models import InferenceResult
def dict_compare(d1, d2): def dict_compare(d1, d2):
......
...@@ -7,15 +7,13 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT ...@@ -7,15 +7,13 @@ from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.data_reader_writer import DataWriter from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.data.dataset import Dataset from magic_pdf.data.dataset import Dataset
from magic_pdf.filter import classify
from magic_pdf.libs.draw_bbox import draw_model_bbox from magic_pdf.libs.draw_bbox import draw_model_bbox
from magic_pdf.libs.version import __version__ from magic_pdf.libs.version import __version__
from magic_pdf.model import InferenceResultBase from magic_pdf.operators.pipes import PipeResult
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.operators import PipeResult
class InferenceResult(InferenceResultBase): class InferenceResult:
def __init__(self, inference_results: list, dataset: Dataset): def __init__(self, inference_results: list, dataset: Dataset):
"""Initialized method. """Initialized method.
......
...@@ -10,7 +10,7 @@ from magic_pdf.config.make_content_config import DropMode, MakeMode ...@@ -10,7 +10,7 @@ from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import FileBasedDataWriter from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.model.operators import InferenceResult from magic_pdf.operators.models import InferenceResult
# from io import BytesIO # from io import BytesIO
# from pypdf import PdfReader, PdfWriter # from pypdf import PdfReader, PdfWriter
...@@ -167,7 +167,7 @@ def do_parse( ...@@ -167,7 +167,7 @@ def do_parse(
logger.error('need model list input') logger.error('need model list input')
exit(2) exit(2)
else: else:
infer_result = InferenceResult(model_list, ds) infer_result = InferenceResult(model_list, ds)
if parse_method == 'ocr': if parse_method == 'ocr':
pipe_result = infer_result.pipe_ocr_mode( pipe_result = infer_result.pipe_ocr_mode(
...@@ -186,7 +186,7 @@ def do_parse( ...@@ -186,7 +186,7 @@ def do_parse(
pipe_result = infer_result.pipe_ocr_mode( pipe_result = infer_result.pipe_ocr_mode(
image_writer, debug_mode=True, lang=ds._lang image_writer, debug_mode=True, lang=ds._lang
) )
if f_draw_model_bbox: if f_draw_model_bbox:
infer_result.draw_model( infer_result.draw_model(
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
Model Api Model Api
========== ==========
.. autoclass:: magic_pdf.model.InferenceResultBase .. autoclass:: magic_pdf.operators.models.InferenceResult
:members: :members:
:inherited-members: :inherited-members:
:show-inheritance: :show-inheritance:
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
Pipeline Api Pipeline Api
============= =============
.. autoclass:: magic_pdf.pipe.operators.PipeResult .. autoclass:: magic_pdf.operators.pipes.PipeResult
:members: :members:
:inherited-members: :inherited-members:
:show-inheritance: :show-inheritance:
\ No newline at end of file
Inference Result Inference Result
================== ==================
.. admonition:: Tip .. admonition:: Tip
...@@ -7,7 +7,7 @@ Inference Result ...@@ -7,7 +7,7 @@ Inference Result
Please first navigate to :doc:`tutorial/pipeline` to get an initial understanding of how the pipeline works; this will help in understanding the content of this section. Please first navigate to :doc:`tutorial/pipeline` to get an initial understanding of how the pipeline works; this will help in understanding the content of this section.
The **InferenceResult** class is a container for storing model inference results and implements a series of methods related to these results, such as draw_model, dump_model. The **InferenceResult** class is a container for storing model inference results and implements a series of methods related to these results, such as draw_model, dump_model.
Checkout :doc:`../api/model_operators` for more details about **InferenceResult** Checkout :doc:`../api/model_operators` for more details about **InferenceResult**
...@@ -56,7 +56,7 @@ Structure Definition ...@@ -56,7 +56,7 @@ Structure Definition
page_info: PageInfo = Field(description="Page metadata") page_info: PageInfo = Field(description="Page metadata")
Example Example
^^^^^^^^^^^ ^^^^^^^^^^^
.. code:: json .. code:: json
...@@ -116,15 +116,15 @@ and bottom-left points respectively. |Poly Coordinate Diagram| ...@@ -116,15 +116,15 @@ and bottom-left points respectively. |Poly Coordinate Diagram|
Inference Result Inference Result
------------------------- -------------------------
.. code:: python .. code:: python
from magic_pdf.model.operators import InferenceResult from magic_pdf.operators.models import InferenceResult
from magic_pdf.data.dataset import Dataset from magic_pdf.data.dataset import Dataset
dataset : Dataset = some_data_set # not real dataset dataset : Dataset = some_data_set # not real dataset
# The inference results of all pages, ordered by page number, are stored in a list as the inference results of MinerU # The inference results of all pages, ordered by page number, are stored in a list as the inference results of MinerU
...@@ -142,4 +142,3 @@ some_model.pdf ...@@ -142,4 +142,3 @@ some_model.pdf
.. |Poly Coordinate Diagram| image:: ../_static/image/poly.png .. |Poly Coordinate Diagram| image:: ../_static/image/poly.png
Pipe Result Pipe Result
============== ==============
.. admonition:: Tip .. admonition:: Tip
...@@ -9,7 +9,7 @@ Pipe Result ...@@ -9,7 +9,7 @@ Pipe Result
Please first navigate to :doc:`tutorial/pipeline` to get an initial understanding of how the pipeline works; this will help in understanding the content of this section. Please first navigate to :doc:`tutorial/pipeline` to get an initial understanding of how the pipeline works; this will help in understanding the content of this section.
The **PipeResult** class is a container for storing pipeline processing results and implements a series of methods related to these results, such as draw_layout, draw_span. The **PipeResult** class is a container for storing pipeline processing results and implements a series of methods related to these results, such as draw_layout, draw_span.
Checkout :doc:`../api/pipe_operators` for more details about **PipeResult** Checkout :doc:`../api/pipe_operators` for more details about **PipeResult**
...@@ -288,14 +288,14 @@ example ...@@ -288,14 +288,14 @@ example
} }
Pipeline Result Pipeline Result
------------------ ------------------
.. code:: python .. code:: python
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
from magic_pdf.pipe.operators import PipeResult from magic_pdf.operators.pipes import PipeResult
from magic_pdf.data.dataset import Dataset from magic_pdf.data.dataset import Dataset
res = pdf_parse_union(*args, **kwargs) res = pdf_parse_union(*args, **kwargs)
res['_parse_type'] = PARSE_TYPE_OCR res['_parse_type'] = PARSE_TYPE_OCR
...@@ -332,4 +332,4 @@ unrecognized inline formulas. ...@@ -332,4 +332,4 @@ unrecognized inline formulas.
.. figure:: ../_static/image/spans_example.png .. figure:: ../_static/image/spans_example.png
:alt: spans example :alt: spans example
spans example spans example
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment