"vscode:/vscode.git/clone" did not exist on "6c69853afd447ec33b146f57dc3b28999c8537ec"
Unverified commit 158e556b, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #1063 from opendatalab/release-0.10.0

Release 0.10.0
parents 038f48d3 30be5017
COLOR_BG_HEADER_TXT_BLOCK = "color_background_header_txt_block"
PAGE_NO = "page-no"  # page number
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area'  # text inside the header/footer area
VERTICAL_TEXT = 'vertical-text'  # vertical text
ROTATE_TEXT = 'rotate-text'  # rotated text
EMPTY_SIDE_BLOCK = 'empty-side-block'  # empty block at the page margin with no content
ON_IMAGE_TEXT = 'on-image-text'  # text lying on top of an image
ON_TABLE_TEXT = 'on-table-text'  # text lying on top of a table


class DropTag:
    PAGE_NUMBER = "page_no"
    HEADER = "header"
    FOOTER = "footer"
    FOOTNOTE = "footnote"
    NOT_IN_LAYOUT = "not_in_layout"
    SPAN_OVERLAP = "span_overlap"
    BLOCK_OVERLAP = "block_overlap"
+from io import BytesIO
+
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.libs.commons import fitz
-from magic_pdf.libs.commons import join_path
+import cv2
+import numpy as np
+from PIL import Image
+
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.hash_utils import compute_sha256


-def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: AbsReaderWriter):
-    """
-    Crop a JPG image out of `page` (page number `page_num`) according to `bbox` and return the image path.
-    save_path must support both s3 and local storage; the image is stored under save_path with filename:
-    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg, with the bbox values truncated to integers.
-    """
+def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
+    """Crop a JPG image out of `page` (page number `page_num`) according to `bbox` and return the image path.
+
+    save_path must support both s3 and local storage; the image is stored under save_path with filename
+    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg, with the bbox values truncated to integers."""
    # Build the filename
-    filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
+    filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
    # Older versions returned the path without the bucket
    img_path = join_path(return_path, filename) if return_path is not None else None
    # Newer versions generate a flat hashed path
-    img_hash256_path = f"{compute_sha256(img_path)}.jpg"
+    img_hash256_path = f'{compute_sha256(img_path)}.jpg'
    # Convert the coordinates into a fitz.Rect object
    rect = fitz.Rect(*bbox)
@@ -28,6 +29,29 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
    byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
-    imageWriter.write(byte_data, img_hash256_path, AbsReaderWriter.MODE_BIN)
+    imageWriter.write(img_hash256_path, byte_data)
    return img_hash256_path
+def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
+    # Convert the coordinates into a fitz.Rect object
+    rect = fitz.Rect(*bbox)
+    # Set the zoom factor to 3x
+    zoom = fitz.Matrix(3, 3)
+    # Render the clipped region
+    pix = page.get_pixmap(clip=rect, matrix=zoom)
+    # Wrap the byte data in a file-like object
+    image_file = BytesIO(pix.tobytes(output='png'))
+    # Open the image with Pillow
+    pil_image = Image.open(image_file)
+    if mode == "cv2":
+        image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
+    elif mode == "pillow":
+        image_result = pil_image
+    else:
+        raise ValueError(f"mode: {mode} is not supported.")
+    return image_result
\ No newline at end of file
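Taken on its own, the new cut_image_to_pil_image helper renders the bbox region of a page at 3x zoom and returns either a Pillow image or an OpenCV BGR array. A minimal usage sketch, assuming a hypothetical sample.pdf:

import fitz  # PyMuPDF

doc = fitz.open('sample.pdf')  # hypothetical input file
page = doc[0]
bbox = (72, 72, 300, 200)  # (x0, y0, x1, y1) in page coordinates

pil_img = cut_image_to_pil_image(bbox, page, mode='pillow')  # PIL.Image.Image
cv2_img = cut_image_to_pil_image(bbox, page, mode='cv2')     # numpy BGR array
print(pil_img.size, cv2_img.shape)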
@@ -163,7 +163,9 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
        page_width = img_dict["width"]
        page_height = img_dict["height"]
        if start_page_id <= index <= end_page_id:
+           page_start = time.time()
            result = custom_model(img)
+           logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
        else:
            result = []
        page_info = {"page_no": index, "height": page_height, "width": page_width}
...
import enum
import json

+from magic_pdf.config.model_block_type import ModelBlockTypeEnum
+from magic_pdf.config.ocr_content_type import CategoryId, ContentType
+from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
+                                               FileBasedDataWriter)
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                    bbox_relative_pos, box_area, calculate_iou,
@@ -9,11 +13,7 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
-from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
-from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter

CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1

@@ -1050,27 +1050,27 @@ class MagicModel:
if __name__ == '__main__':
-    drw = DiskReaderWriter(r'D:/project/20231108code-clean')
+    drw = FileBasedDataReader(r'D:/project/20231108code-clean')
    if 0:
        pdf_file_path = r'linshixuqiu\19983-00.pdf'
        model_file_path = r'linshixuqiu\19983-00_new.json'
-        pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
-        model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
+        pdf_bytes = drw.read(pdf_file_path)
+        model_json_txt = drw.read(model_file_path).decode()
        model_list = json.loads(model_json_txt)
        write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
        img_bucket_path = 'imgs'
-        img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
+        img_writer = FileBasedDataWriter(join_path(write_path, img_bucket_path))
        pdf_docs = fitz.open('pdf', pdf_bytes)
        magic_model = MagicModel(model_list, pdf_docs)

    if 1:
+        from magic_pdf.data.dataset import PymuDocDataset
        model_list = json.loads(
            drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
        )
-        pdf_bytes = drw.read(
-            '/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
-        )
-        pdf_docs = fitz.open('pdf', pdf_bytes)
-        magic_model = MagicModel(model_list, pdf_docs)
+        pdf_bytes = drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf')
+        magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
    for i in range(7):
        print(magic_model.get_imgs(i))
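The demo above tracks the 0.10.0 migration from DiskReaderWriter/AbsReaderWriter to the new data_reader_writer classes: read() always returns bytes (no MODE_BIN/MODE_TXT flags), and write() takes the path first. A minimal sketch of the new pattern, with hypothetical paths:

from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
                                               FileBasedDataWriter)

reader = FileBasedDataReader('/tmp/work')          # hypothetical working dir
pdf_bytes = reader.read('sample.pdf')              # bytes; no MODE_BIN argument
model_json = reader.read('sample.json').decode()   # decode the bytes yourself for text

writer = FileBasedDataWriter('/tmp/work/imgs')
writer.write('page_0.jpg', b'\xff\xd8...')         # path first, then the payload bytes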
+# flake8: noqa
-import numpy as np
-import torch
-from loguru import logger
import os
import time

import cv2
+import numpy as np
+import torch
import yaml
+from loguru import logger
from PIL import Image

os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
@@ -13,16 +15,18 @@ os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
try:
    import torchtext

-    if torchtext.__version__ >= "0.18.0":
+    if torchtext.__version__ >= '0.18.0':
        torchtext.disable_torchtext_deprecation_warning()
except ImportError:
    pass

-from magic_pdf.libs.Constants import *
+from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
-from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
+from magic_pdf.model.sub_modules.model_utils import (
+    clean_vram, crop_img, get_res_list_from_layout_res)
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+    get_adjusted_mfdetrec_res, get_ocr_result_list)
class CustomPEKModel:
@@ -41,42 +45,54 @@ class CustomPEKModel:
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        # Build the full path to the model_configs.yaml file
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, 'r', encoding='utf-8') as f:
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # Initialize the parsing configuration
        # layout config
        self.layout_config = kwargs.get('layout_config')
        self.layout_model_name = self.layout_config.get(
            'model', MODEL_NAME.DocLayout_YOLO
        )
        # formula config
        self.formula_config = kwargs.get('formula_config')
        self.mfd_model_name = self.formula_config.get(
            'mfd_model', MODEL_NAME.YOLO_V8_MFD
        )
        self.mfr_model_name = self.formula_config.get(
            'mfr_model', MODEL_NAME.UniMerNet_v2_Small
        )
        self.apply_formula = self.formula_config.get('enable', True)
        # table config
        self.table_config = kwargs.get('table_config')
        self.apply_table = self.table_config.get('enable', False)
        self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
        self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
        # ocr config
        self.apply_ocr = ocr
        self.lang = kwargs.get('lang', None)

        logger.info(
            'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
            'apply_table: {}, table_model: {}, lang: {}'.format(
                self.layout_model_name,
                self.apply_formula,
                self.apply_ocr,
                self.apply_table,
                self.table_model_name,
                self.lang,
            )
        )
        # Initialize the parsing pipeline
        self.device = kwargs.get('device', 'cpu')
        logger.info('using device: {}'.format(self.device))
        models_dir = kwargs.get(
            'models_dir', os.path.join(root_dir, 'resources', 'models')
        )
        logger.info('using models_dir: {}'.format(models_dir))

        atom_model_manager = AtomModelSingleton()
@@ -85,18 +101,24 @@ class CustomPEKModel:
            # Initialize the formula detection model
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.mfd_model_name]
                    )
                ),
                device=self.device,
            )

            # Initialize the formula recognition model
            mfr_weight_dir = str(
                os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
            )
            mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
                device=self.device,
            )

        # Initialize the layout model
@@ -104,19 +126,30 @@ class CustomPEKModel:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
                layout_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                layout_config_file=str(
                    os.path.join(
                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                    )
                ),
                device=self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
                doclayout_yolo_weights=str(
                    os.path.join(
                        models_dir, self.configs['weights'][self.layout_model_name]
                    )
                ),
                device=self.device,
            )
        # Initialize OCR
-       if self.apply_ocr:
        self.ocr_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            ocr_show_log=show_log,
@@ -125,21 +158,19 @@ class CustomPEKModel:
        )
        # init table model
        if self.apply_table:
            table_model_dir = self.configs['weights'][self.table_model_name]
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=str(os.path.join(models_dir, table_model_dir)),
                table_max_time=self.table_max_time,
                device=self.device,
            )

        logger.info('DocAnalysis init done!')
    def __call__(self, image):
-       page_start = time.time()
        # layout detection
        layout_start = time.time()
        layout_res = []
@@ -150,7 +181,7 @@ class CustomPEKModel:
            # doclayout_yolo
            layout_res = self.layout_model.predict(image)

        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')

        pil_img = Image.fromarray(image)
@@ -158,23 +189,24 @@ class CustomPEKModel:
            # formula detection
            mfd_start = time.time()
            mfd_res = self.mfd_model.predict(image)
            logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')

            # formula recognition
            mfr_start = time.time()
            formula_list = self.mfr_model.predict(mfd_res, image)
            layout_res.extend(formula_list)
            mfr_cost = round(time.time() - mfr_start, 2)
            logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

        # free up VRAM
        clean_vram(self.device, vram_threshold=8)

        # Extract the OCR, table, and formula regions from layout_res
        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
            get_res_list_from_layout_res(layout_res)
        )
        # OCR recognition
-       if self.apply_ocr:
        ocr_start = time.time()
        # Process each area that requires OCR processing
        for res in ocr_res_list:
@@ -183,7 +215,10 @@ class CustomPEKModel:
            # OCR recognition
            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+           if self.apply_ocr:
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
+           else:
+               ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]

            # Integration results
            if ocr_res:
@@ -191,7 +226,10 @@ class CustomPEKModel:
                layout_res.extend(ocr_result_list)

        ocr_cost = round(time.time() - ocr_start, 2)
+       if self.apply_ocr:
            logger.info(f"ocr time: {ocr_cost}")
+       else:
+           logger.info(f"det time: {ocr_cost}")
        # table recognition
        if self.apply_table:
@@ -202,27 +240,35 @@ class CustomPEKModel:
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
                        table_result = self.table_model.predict(new_image, 'html')
                        if len(table_result) > 0:
                            html_code = table_result[0]
                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                    html_code = self.table_model.img2html(new_image)
                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
                    html_code, table_cell_bboxes, elapse = self.table_model.predict(
                        new_image
                    )
                run_time = time.time() - single_table_start_time
                if run_time > self.table_max_time:
                    logger.warning(
                        f'table recognition processing exceeds max time {self.table_max_time}s'
                    )
                # Check whether the model returned a usable result
                if html_code:
                    expected_ending = html_code.strip().endswith(
                        '</html>'
                    ) or html_code.strip().endswith('</table>')
                    if expected_ending:
                        res['html'] = html_code
                    else:
                        logger.warning(
                            'table recognition processing fails, not found expected HTML table end'
                        )
                else:
                    logger.warning(
                        'table recognition processing fails, not get html return'
                    )
            logger.info(f'table time: {round(time.time() - table_start, 2)}')

-       logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
        return layout_res
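One behavioral consequence of the changes above: text detection now always runs, and apply_ocr only decides whether recognition runs as well (rec=False keeps detection-only). A hedged sketch of driving both modes; the constructor arguments below are illustrative, not the project's documented defaults:

import numpy as np

model = CustomPEKModel(
    ocr=True,                 # True: det + rec ("ocr time"); False: det only ("det time")
    show_log=False,
    layout_config={'model': MODEL_NAME.DocLayout_YOLO},
    formula_config={'enable': True},
    table_config={'enable': False},
    device='cpu',
)
page_image = np.zeros((1024, 768, 3), dtype=np.uint8)  # stand-in page render
layout_res = model(page_image)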
from loguru import logger

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
    DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
    Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
    ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
    RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
    StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
    TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
        table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
    elif table_model_type == MODEL_NAME.TABLE_MASTER:
        config = {
            'model_dir': model_path,
            'device': _device_
        }
        table_model = TableMasterPaddleModel(config)
    elif table_model_type == MODEL_NAME.RAPID_TABLE:
        table_model = RapidTableModel()
    else:
        logger.error('table model type not allow')
        exit(1)

    return table_model
@@ -58,7 +63,7 @@ def ocr_model_init(show_log: bool = False,
                   use_dilation=True,
                   det_db_unclip_ratio=1.8,
                   ):
-    if lang is not None:
+    if lang is not None and lang != '':
        model = ModifiedPaddleOCR(
            show_log=show_log,
            det_db_box_thresh=det_db_box_thresh,
@@ -87,8 +92,8 @@ class AtomModelSingleton:
        return cls._instance

    def get_atom_model(self, atom_model_name: str, **kwargs):
        lang = kwargs.get('lang', None)
        layout_model_name = kwargs.get('layout_model_name', None)
        key = (atom_model_name, layout_model_name, lang)
        if key not in self._models:
            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
@@ -98,47 +103,47 @@ class AtomModelSingleton:
def atom_model_init(model_name: str, **kwargs):
    atom_model = None
    if model_name == AtomicModel.Layout:
        if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
            atom_model = layout_model_init(
                kwargs.get('layout_weights'),
                kwargs.get('layout_config_file'),
                kwargs.get('device')
            )
        elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
            atom_model = doclayout_yolo_model_init(
                kwargs.get('doclayout_yolo_weights'),
                kwargs.get('device')
            )
    elif model_name == AtomicModel.MFD:
        atom_model = mfd_model_init(
            kwargs.get('mfd_weights'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.MFR:
        atom_model = mfr_model_init(
            kwargs.get('mfr_weight_dir'),
            kwargs.get('mfr_cfg_path'),
            kwargs.get('device')
        )
    elif model_name == AtomicModel.OCR:
        atom_model = ocr_model_init(
            kwargs.get('ocr_show_log'),
            kwargs.get('det_db_box_thresh'),
            kwargs.get('lang')
        )
    elif model_name == AtomicModel.Table:
        atom_model = table_model_init(
            kwargs.get('table_model_name'),
            kwargs.get('table_model_path'),
            kwargs.get('table_max_time'),
            kwargs.get('device')
        )
    else:
        logger.error('model name not allow')
        exit(1)

    if atom_model is None:
        logger.error('model init failed')
        exit(1)
    else:
        return atom_model
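AtomModelSingleton caches each initialized model under the key (atom_model_name, layout_model_name, lang), so repeated requests with an identical key return the same instance while a different lang spawns a fresh model. A small sketch of that behavior; the OCR kwargs below are illustrative:

manager = AtomModelSingleton()

ocr_en_1 = manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    ocr_show_log=False, det_db_box_thresh=0.3, lang='en',
)
ocr_en_2 = manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    ocr_show_log=False, det_db_box_thresh=0.3, lang='en',
)
assert ocr_en_1 is ocr_en_2   # same cache key (OCR, None, 'en') -> same object

ocr_ch = manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    ocr_show_log=False, det_db_box_thresh=0.3, lang='ch',
)
assert ocr_ch is not ocr_en_1  # different lang -> different key -> new model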
@@ -71,7 +71,13 @@ def remove_intervals(original, masks):
def update_det_boxes(dt_boxes, mfd_res):
    new_dt_boxes = []
+   angle_boxes_list = []
    for text_box in dt_boxes:
+       if calculate_is_angle(text_box):
+           angle_boxes_list.append(text_box)
+           continue
        text_bbox = points_to_bbox(text_box)
        masks_list = []
        for mf_box in mfd_res:
@@ -85,6 +91,9 @@ def update_det_boxes(dt_boxes, mfd_res):
                temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
            if len(temp_dt_box) > 0:
                new_dt_boxes.extend(temp_dt_box)
+   new_dt_boxes.extend(angle_boxes_list)
    return new_dt_boxes
@@ -143,9 +152,11 @@ def merge_det_boxes(dt_boxes):
    angle_boxes_list = []
    for text_box in dt_boxes:
        text_bbox = points_to_bbox(text_box)
-       if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]:
+       if calculate_is_angle(text_box):
            angle_boxes_list.append(text_box)
            continue
        text_box_dict = {
            'bbox': text_bbox,
            'type': 'text',
@@ -200,15 +211,21 @@ def get_ocr_result_list(ocr_res, useful_list):
    ocr_result_list = []
    for box_ocr_res in ocr_res:
+       if len(box_ocr_res) == 2:
            p1, p2, p3, p4 = box_ocr_res[0]
            text, score = box_ocr_res[1]
-       average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
-       if average_angle_degrees > 0.5:
+       else:
+           p1, p2, p3, p4 = box_ocr_res
+           text, score = "", 1
+       # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
+       # if average_angle_degrees > 0.5:
+       poly = [p1, p2, p3, p4]
+       if calculate_is_angle(poly):
            # logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
            # The box is tilted against the x-axis (historically: by more than 0.5 degrees), so correct the boundaries
            # Compute the geometric center
-           x_center = sum(point[0] for point in box_ocr_res[0]) / 4
-           y_center = sum(point[1] for point in box_ocr_res[0]) / 4
+           x_center = sum(point[0] for point in poly) / 4
+           y_center = sum(point[1] for point in poly) / 4
            new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
            new_width = p3[0] - p1[0]
            p1 = [x_center - new_width / 2, y_center - new_height / 2]
@@ -257,3 +274,12 @@ def calculate_angle_degrees(poly):
    # logger.info(f"average_angle_degrees: {average_angle_degrees}")
    return average_angle_degrees
+def calculate_is_angle(poly):
+    p1, p2, p3, p4 = poly
+    height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
+    if 0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height:
+        return False
+    else:
+        # logger.info((p3[1] - p1[1])/height)
+        return True
\ No newline at end of file
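A quick sanity check of the new heuristic, assuming the usual PaddleOCR corner order p1..p4 clockwise from the top-left: for an axis-aligned box the p1-to-p3 vertical span equals the mean edge height, while a visibly tilted quad falls outside the 0.8x-1.2x band:

flat = [(0, 0), (100, 0), (100, 20), (0, 20)]
print(calculate_is_angle(flat))    # False: p3.y - p1.y == 20 == mean height

tilted = [(0, 0), (100, 30), (95, 50), (-5, 20)]
print(calculate_is_angle(tilted))  # True: p3.y - p1.y == 50, mean height == 20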
@@ -78,9 +78,18 @@ class ModifiedPaddleOCR(PaddleOCR):
            for idx, img in enumerate(imgs):
                img = preprocess_image(img)
                dt_boxes, elapse = self.text_detector(img)
-               if not dt_boxes:
+               if dt_boxes is None:
                    ocr_res.append(None)
                    continue
+               dt_boxes = sorted_boxes(dt_boxes)
+               # merge_det_boxes and update_det_boxes both convert each poly to a bbox and back,
+               # so heavily tilted text boxes must be filtered out first
+               dt_boxes = merge_det_boxes(dt_boxes)
+               if mfd_res:
+                   bef = time.time()
+                   dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+                   aft = time.time()
+                   logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
+                       len(dt_boxes), aft - bef))
                tmp_res = [box.tolist() for box in dt_boxes]
                ocr_res.append(tmp_res)
            return ocr_res
@@ -125,9 +134,8 @@ class ModifiedPaddleOCR(PaddleOCR):
        dt_boxes = sorted_boxes(dt_boxes)

-       # @todo merging currently happens at the bbox level, which handles tilted text lines poorly; it should be reworked to merge polys
-       # dt_boxes = merge_det_boxes(dt_boxes)
+       # merge_det_boxes and update_det_boxes both convert each poly to a bbox and back,
+       # so heavily tilted text boxes must be filtered out first
+       dt_boxes = merge_det_boxes(dt_boxes)
        if mfd_res:
            bef = time.time()
...
@@ -10,5 +10,7 @@ class RapidTableModel(object):
    def predict(self, image):
        ocr_result, _ = self.ocr_engine(np.asarray(image))
+       if ocr_result is None:
+           return None, None, None
        html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
        return html_code, table_cell_bboxes, elapse
\ No newline at end of file
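With this guard, predict() may now return (None, None, None) when the OCR engine finds no text in the table crop, so callers should check the HTML before using it. A hedged caller-side sketch; the image path and table_model variable are hypothetical:

from PIL import Image

image = Image.open('table_region.png')          # hypothetical table crop
html_code, cell_bboxes, elapse = table_model.predict(image)
if html_code is None:
    print('no text detected in table region; skipping')  # avoid crashing downstream
else:
    print(f'table recognized in {elapse:.2f}s')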
+import os

import cv2
+import numpy as np
from paddleocr.ppstructure.table.predict_table import TableSystem
from paddleocr.ppstructure.utility import init_args
-from magic_pdf.libs.Constants import *
-import os
from PIL import Image
-import numpy as np

+from magic_pdf.config.constants import *  # noqa: F403
class TableMasterPaddleModel(object):
    """This class is responsible for converting image of table into HTML format
    using a pre-trained model.

    Attributes:
    - table_sys: An instance of TableSystem initialized with parsed arguments.
@@ -40,30 +42,30 @@ class TableMasterPaddleModel(object):
        image = np.asarray(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        pred_res, _ = self.table_sys(image)
        pred_html = pred_res['html']
        # res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
        #     "</table></body></html>","") + "</table></td>\n"
        return pred_html

    def parse_args(self, **kwargs):
        parser = init_args()
        model_dir = kwargs.get('model_dir')
        table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)  # noqa: F405
        table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)  # noqa: F405
        det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)  # noqa: F405
        rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)  # noqa: F405
        rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)  # noqa: F405
        device = kwargs.get('device', 'cpu')
        use_gpu = True if device.startswith('cuda') else False
        config = {
            'use_gpu': use_gpu,
            'table_max_len': kwargs.get('table_max_len', TABLE_MAX_LEN),  # noqa: F405
            'table_algorithm': 'TableMaster',
            'table_model_dir': table_model_dir,
            'table_char_dict_path': table_char_dict_path,
            'det_model_dir': det_model_dir,
            'rec_model_dir': rec_model_dir,
            'rec_char_dict_path': rec_char_dict_path,
        }
        parser.set_defaults(**config)
        return parser.parse_args([])
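As the diff shows, TableMasterPaddleModel is driven by a plain config dict carrying the weights directory and the target device; any device string starting with 'cuda' flips use_gpu on. A hedged construction sketch with a hypothetical model directory:

from PIL import Image

config = {
    'model_dir': '/models/table_master',  # hypothetical weights root containing the model sub-dirs
    'device': 'cuda',                     # 'cuda...' -> use_gpu=True, anything else -> CPU
}
table_model = TableMasterPaddleModel(config)
html = table_model.img2html(Image.open('table.png'))  # predicted HTML string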
import os
import json

from magic_pdf.para.commons import *
from magic_pdf.para.raw_processor import RawBlockProcessor
from magic_pdf.para.layout_match_processor import LayoutFilterProcessor
from magic_pdf.para.stats import BlockStatisticsCalculator
from magic_pdf.para.stats import DocStatisticsCalculator
from magic_pdf.para.title_processor import TitleProcessor
from magic_pdf.para.block_termination_processor import BlockTerminationProcessor
from magic_pdf.para.block_continuation_processor import BlockContinuationProcessor
from magic_pdf.para.draw import DrawAnnos
from magic_pdf.para.exceptions import (
    DenseSingleLineBlockException,
    TitleDetectionException,
    TitleLevelException,
    ParaSplitException,
    ParaMergeException,
    DiscardByException,
)

if sys.version_info[0] >= 3:
    sys.stdout.reconfigure(encoding="utf-8")  # type: ignore


class ParaProcessPipeline:
    def __init__(self) -> None:
        pass

    def para_process_pipeline(self, pdf_info_dict, para_debug_mode=None, input_pdf_path=None, output_pdf_path=None):
        """
        This function processes the paragraphs, including:
        1. Read raw input json file into pdf_dic
        2. Detect and replace equations
        3. Combine spans into a natural line
        4. Check if the paragraphs are inside bboxes passed from "layout_bboxes" key
        5. Compute statistics for each block
        6. Detect titles in the document
        7. Detect paragraphs inside each block
        8. Divide the level of the titles
        9. Detect and combine paragraphs from different blocks into one paragraph
        10. Check whether the final results after checking headings, dividing paragraphs within blocks, and merging paragraphs between blocks are plausible and reasonable.
        11. Draw annotations on the pdf file

        Parameters
        ----------
        pdf_dic_json_fpath : str
            path to the pdf dictionary json file.
            Notice: data noises, including overlap blocks, header, footer, watermark, vertical margin note have been removed already.
        input_pdf_doc : str
            path to the input pdf file
        output_pdf_path : str
            path to the output pdf file

        Returns
        -------
        pdf_dict : dict
            result dictionary
        """

        error_info = None

        output_json_file = ""
        output_dir = ""

        if input_pdf_path is not None:
            input_pdf_path = os.path.abspath(input_pdf_path)
            # print_green_on_red(f">>>>>>>>>>>>>>>>>>> Process the paragraphs of {input_pdf_path}")

        if output_pdf_path is not None:
            output_dir = os.path.dirname(output_pdf_path)
            output_json_file = f"{output_dir}/pdf_dic.json"

        def __save_pdf_dic(pdf_dic, output_pdf_path, stage="0", para_debug_mode=para_debug_mode):
            """
            Save the pdf_dic to a json file
            """
            output_pdf_file_name = os.path.basename(output_pdf_path)
            # output_dir = os.path.dirname(output_pdf_path)
            output_dir = "\\tmp\\pdf_parse"
            output_pdf_file_name = output_pdf_file_name.replace(".pdf", f"_stage_{stage}.json")
            pdf_dic_json_fpath = os.path.join(output_dir, output_pdf_file_name)

            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            if para_debug_mode == "full":
                with open(pdf_dic_json_fpath, "w", encoding="utf-8") as f:
                    json.dump(pdf_dic, f, indent=2, ensure_ascii=False)

            # Validate the output already exists
            if not os.path.exists(pdf_dic_json_fpath):
                print_red(f"Failed to save the pdf_dic to {pdf_dic_json_fpath}")
                return None
            else:
                print_green(f"Succeed to save the pdf_dic to {pdf_dic_json_fpath}")

            return pdf_dic_json_fpath

        """
        Preprocess the lines of block
        """
        # Find and replace the interline and inline equations, should be better done before the paragraph processing
        # Create "para_blocks" for each page.
        # equationProcessor = EquationsProcessor()
        # pdf_dic = equationProcessor.batch_process_blocks(pdf_info_dict)

        # Combine spans into a natural line
        rawBlockProcessor = RawBlockProcessor()
        pdf_dic = rawBlockProcessor.batch_process_blocks(pdf_info_dict)
        # print(f"pdf_dic['page_0']['para_blocks'][0]: {pdf_dic['page_0']['para_blocks'][0]}", end="\n\n")

        # Check if the paragraphs are inside bboxes passed from "layout_bboxes" key
        layoutFilter = LayoutFilterProcessor()
        pdf_dic = layoutFilter.batch_process_blocks(pdf_dic)

        # Compute statistics for each block
        blockStatisticsCalculator = BlockStatisticsCalculator()
        pdf_dic = blockStatisticsCalculator.batch_process_blocks(pdf_dic)
        # print(f"pdf_dic['page_0']['para_blocks'][0]: {pdf_dic['page_0']['para_blocks'][0]}", end="\n\n")

        # Compute statistics for all blocks (namely this pdf document)
        docStatisticsCalculator = DocStatisticsCalculator()
        pdf_dic = docStatisticsCalculator.calc_stats_of_doc(pdf_dic)
        # print(f"pdf_dic['statistics']: {pdf_dic['statistics']}", end="\n\n")

        # Dump the first three stages of pdf_dic to a json file
        if para_debug_mode == "full":
            pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="0", para_debug_mode=para_debug_mode)

        """
        Detect titles in the document
        """
        doc_statistics = pdf_dic["statistics"]
        titleProcessor = TitleProcessor(doc_statistics)
        pdf_dic = titleProcessor.batch_process_blocks_detect_titles(pdf_dic)

        if para_debug_mode == "full":
            pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="1", para_debug_mode=para_debug_mode)

        """
        Detect and divide the level of the titles
        """
        titleProcessor = TitleProcessor()
        pdf_dic = titleProcessor.batch_process_blocks_recog_title_level(pdf_dic)

        if para_debug_mode == "full":
            pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="2", para_debug_mode=para_debug_mode)

        """
        Detect and split paragraphs inside each block
        """
        blockInnerParasProcessor = BlockTerminationProcessor()
        pdf_dic = blockInnerParasProcessor.batch_process_blocks(pdf_dic)

        if para_debug_mode == "full":
            pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="3", para_debug_mode=para_debug_mode)
        # pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="3", para_debug_mode="full")
        # print_green(f"pdf_dic_json_fpath: {pdf_dic_json_fpath}")

        """
        Detect and combine paragraphs from different blocks into one paragraph
        """
        blockContinuationProcessor = BlockContinuationProcessor()
        pdf_dic = blockContinuationProcessor.batch_tag_paras(pdf_dic)
        pdf_dic = blockContinuationProcessor.batch_merge_paras(pdf_dic)

        if para_debug_mode == "full":
            pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="4", para_debug_mode=para_debug_mode)
        # pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="4", para_debug_mode="full")
        # print_green(f"pdf_dic_json_fpath: {pdf_dic_json_fpath}")

        """
        Discard pdf files by checking exceptions and return the error info to the caller
        """
        discardByException = DiscardByException()

        is_discard_by_single_line_block = discardByException.discard_by_single_line_block(
            pdf_dic, exception=DenseSingleLineBlockException()
        )
        is_discard_by_title_detection = discardByException.discard_by_title_detection(
            pdf_dic, exception=TitleDetectionException()
        )
        is_discard_by_title_level = discardByException.discard_by_title_level(pdf_dic, exception=TitleLevelException())
        is_discard_by_split_para = discardByException.discard_by_split_para(pdf_dic, exception=ParaSplitException())
        is_discard_by_merge_para = discardByException.discard_by_merge_para(pdf_dic, exception=ParaMergeException())

        """
        if any(
            info is not None
            for info in [
                is_discard_by_single_line_block,
                is_discard_by_title_detection,
                is_discard_by_title_level,
                is_discard_by_split_para,
                is_discard_by_merge_para,
            ]
        ):
            error_info = next(
                (
                    info
                    for info in [
                        is_discard_by_single_line_block,
                        is_discard_by_title_detection,
                        is_discard_by_title_level,
                        is_discard_by_split_para,
                        is_discard_by_merge_para,
                    ]
                    if info is not None
                ),
                None,
            )
            return pdf_dic, error_info
        """

        if is_discard_by_single_line_block is not None:
            error_info = is_discard_by_single_line_block
        elif is_discard_by_title_detection is not None:
            error_info = is_discard_by_title_detection
        elif is_discard_by_title_level is not None:
            error_info = is_discard_by_title_level
        elif is_discard_by_split_para is not None:
            error_info = is_discard_by_split_para
        elif is_discard_by_merge_para is not None:
            error_info = is_discard_by_merge_para

        if error_info is not None:
            return pdf_dic, error_info

        """
        Dump the final pdf_dic to a json file
        """
        if para_debug_mode is not None:
            with open(output_json_file, "w", encoding="utf-8") as f:
                json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)

        """
        Draw the annotations
        """
        if para_debug_mode is not None:
            drawAnnos = DrawAnnos()
            drawAnnos.draw_annos(input_pdf_path, pdf_dic, output_pdf_path)

        """
        Remove the intermediate files which are generated in the process of paragraph processing if debug_mode is simple
        """
        if para_debug_mode is not None:
            for fpath in os.listdir(output_dir):
                if fpath.endswith(".json") and "stage" in fpath:
                    os.remove(os.path.join(output_dir, fpath))

        return pdf_dic, error_info
import copy

-from loguru import logger
-
-from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
-from magic_pdf.libs.ocr_content_type import BlockType, ContentType
+from magic_pdf.config.constants import CROSS_PAGE, LINES_DELETED
+from magic_pdf.config.ocr_content_type import BlockType, ContentType

-LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
+LINE_STOP_FLAG = (
+    '.',
+    '!',
+    '?',
+    '。',
+    '!',
+    '?',
+    ')',
+    ')',
+    '"',
+    '”',
+    ':',
+    ':',
+    ';',
+    ';',
+)
LIST_END_FLAG = ('.', '。', ';', ';')


class ListLineTag:
    IS_LIST_START_LINE = 'is_list_start_line'
    IS_LIST_END_LINE = 'is_list_end_line'
def __process_blocks(blocks):
@@ -27,12 +40,14 @@ def __process_blocks(blocks):
        # If the current block is a text block
        if current_block['type'] == 'text':
            current_block['bbox_fs'] = copy.deepcopy(current_block['bbox'])
            if 'lines' in current_block and len(current_block['lines']) > 0:
                current_block['bbox_fs'] = [
                    min([line['bbox'][0] for line in current_block['lines']]),
                    min([line['bbox'][1] for line in current_block['lines']]),
                    max([line['bbox'][2] for line in current_block['lines']]),
                    max([line['bbox'][3] for line in current_block['lines']]),
                ]
            current_group.append(current_block)

        # Check whether the next block exists
@@ -64,6 +79,7 @@ def __is_list_or_index_block(block):
    line_height = first_line['bbox'][3] - first_line['bbox'][1]
    block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
    block_height = block['bbox_fs'][3] - block['bbox_fs'][1]
+   page_weight, page_height = block['page_size']

    left_close_num = 0
    left_not_close_num = 0
@@ -75,10 +91,17 @@ def __is_list_or_index_block(block):
    multiple_para_flag = False
    last_line = block['lines'][-1]

+   if page_weight == 0:
+       block_weight_radio = 0
+   else:
+       block_weight_radio = block_weight / page_weight
+   # logger.info(f"block_weight_radio: {block_weight_radio}")

    # If the first line is not flush left but flush right, while the last line is flush left but not flush right (the first line may be allowed to fall short of the right edge)
    if (
        first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
        and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
        and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
    ):
        multiple_para_flag = True
@@ -86,14 +109,14 @@ def __is_list_or_index_block(block):
        line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
        block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
        if (
            line['bbox'][0] - block['bbox_fs'][0] > 0.8 * line_height
            and block['bbox_fs'][2] - line['bbox'][2] > 0.8 * line_height
        ):
            external_sides_not_close_num += 1
        if abs(line_mid_x - block_mid_x) < line_height / 2:
            center_close_num += 1

        line_text = ''

        for span in line['spans']:
            span_type = span['type']
@@ -114,7 +137,12 @@ def __is_list_or_index_block(block):
                right_close_num += 1
            else:
                # If the right side is not flush, check for a sizable gap; roughly 0.3 of the block width is a rule-of-thumb threshold
                # Wide blocks can use a smaller threshold; narrow blocks need a larger one
+               if block_weight_radio >= 0.5:
                    closed_area = 0.26 * block_weight
+               else:
+                   closed_area = 0.36 * block_weight
                if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
                    right_not_close_num += 1
@@ -136,15 +164,19 @@ def __is_list_or_index_block(block):
            if line_text[-1].isdigit():
                num_end_count += 1

    if (
        num_start_count / len(lines_text_list) >= 0.8
        or num_end_count / len(lines_text_list) >= 0.8
    ):
        line_num_flag = True
    if flag_end_count / len(lines_text_list) >= 0.8:
        line_end_flag = True

    # Some tables of contents are not flush on the right; currently, if either the left or the right side is fully flush and the numbering rule holds, treat the block as an index
    if (
        left_close_num / len(block['lines']) >= 0.8
        or right_close_num / len(block['lines']) >= 0.8
    ) and line_num_flag:
        for line in block['lines']:
            line[ListLineTag.IS_LIST_START_LINE] = True
        return BlockType.Index
@@ -152,17 +184,21 @@ def __is_list_or_index_block(block):
    # Detect the special case of a fully centered list where every line breaks: multiple lines, most lines not close to either edge, and the mid-x of every line nearly equal
    # Additional constraint: the block aspect ratio must qualify
    elif (
        external_sides_not_close_num >= 2
        and center_close_num == len(block['lines'])
        and external_sides_not_close_num / len(block['lines']) >= 0.5
        and block_height / block_weight > 0.4
    ):
        for line in block['lines']:
            line[ListLineTag.IS_LIST_START_LINE] = True
        return BlockType.List

    elif (
        left_close_num >= 2
        and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
        and not multiple_para_flag
        # and block_weight_radio > 0.27
    ):
        # Handle a special un-indented list: all lines are flush left, and the right-side gap decides whether a line ends an item
        if left_close_num / len(block['lines']) > 0.8:
            # A short-item list where every item is a single line and all lines are flush left
@@ -173,10 +209,15 @@ def __is_list_or_index_block(block):
            # Most line items carry an end flag; split the items on that end flag
            elif line_end_flag:
                for i, line in enumerate(block['lines']):
                    if (
                        len(lines_text_list[i]) > 0
                        and lines_text_list[i][-1] in LIST_END_FLAG
                    ):
                        line[ListLineTag.IS_LIST_END_LINE] = True
                        if i + 1 < len(block['lines']):
                            block['lines'][i + 1][
                                ListLineTag.IS_LIST_START_LINE
                            ] = True
            # Line items mostly have no end flag and no indentation; use the right-side gap to decide which lines end an item
            else:
                line_start_flag = False
@@ -185,7 +226,10 @@ def __is_list_or_index_block(block):
                        line[ListLineTag.IS_LIST_START_LINE] = True
                        line_start_flag = False

                    if (
                        abs(block['bbox_fs'][2] - line['bbox'][2])
                        > 0.1 * block_weight
                    ):
                        line[ListLineTag.IS_LIST_END_LINE] = True
                        line_start_flag = True
    # A special indented ordered list: start lines are not flush left and begin with a digit; end lines end with an IS_LIST_END_FLAG character, and their count matches the start lines
@@ -223,10 +267,17 @@ def __merge_2_text_blocks(block1, block2):
        if len(last_line['spans']) > 0:
            last_span = last_line['spans'][-1]
            line_height = last_line['bbox'][3] - last_line['bbox'][1]
+           if len(first_line['spans']) > 0:
+               first_span = first_line['spans'][0]
+               if len(first_span['content']) > 0:
+                   span_start_with_num = first_span['content'][0].isdigit()
            if (
                abs(block2['bbox_fs'][2] - last_line['bbox'][2])
                < line_height
                and not last_span['content'].endswith(LINE_STOP_FLAG)
                # Also do not merge when the two block widths differ by more than 2x
                and abs(block1_weight - block2_weight) < min_block_weight
+               and not span_start_with_num
            ):
                if block1['page_num'] != block2['page_num']:
                    for line in block1['lines']:
@@ -263,7 +314,6 @@ def __is_list_group(text_blocks_group):


def __para_merge_page(blocks):
    page_text_blocks_groups = __process_blocks(blocks)
    for text_blocks_group in page_text_blocks_groups:

        if len(text_blocks_group) > 0:
            # Before merging, decide for every block whether it is a list or index block
            for block in text_blocks_group:
@@ -272,7 +322,6 @@ def __para_merge_page(blocks):
                # logger.info(f"{block['type']}:{block}")

        if len(text_blocks_group) > 1:
            # Before merging, decide whether this group is a list group
            is_list_group = __is_list_group(text_blocks_group)
@@ -284,11 +333,18 @@ def __para_merge_page(blocks):
                if i - 1 >= 0:
                    prev_block = text_blocks_group[i - 1]
                    if (
                        current_block['type'] == 'text'
                        and prev_block['type'] == 'text'
                        and not is_list_group
                    ):
                        __merge_2_text_blocks(current_block, prev_block)
                    elif (
                        current_block['type'] == BlockType.List
                        and prev_block['type'] == BlockType.List
                    ) or (
                        current_block['type'] == BlockType.Index
                        and prev_block['type'] == BlockType.Index
                    ):
                        __merge_2_list_blocks(current_block, prev_block)
@@ -296,12 +352,13 @@ def __para_merge_page(blocks):
            continue


def para_split(pdf_info_dict):
    all_blocks = []
    for page_num, page in pdf_info_dict.items():
        blocks = copy.deepcopy(page['preproc_blocks'])
        for block in blocks:
            block['page_num'] = page_num
            block['page_size'] = page['page_size']
        all_blocks.extend(blocks)

    __para_merge_page(all_blocks)
@@ -317,4 +374,4 @@ if __name__ == '__main__':
    # Call the function
    groups = __process_blocks(input_blocks)
    for group_index, group in enumerate(groups):
        print(f'Group {group_index}: {group}')
@@ -9,6 +9,7 @@ def parse_pdf_by_ocr(pdf_bytes,
                     start_page_id=0,
                     end_page_id=None,
                     debug_mode=False,
                     lang=None,
                     ):
    dataset = PymuDocDataset(pdf_bytes)
    return pdf_parse_union(dataset,
@@ -18,4 +19,5 @@ def parse_pdf_by_ocr(pdf_bytes,
                           start_page_id=start_page_id,
                           end_page_id=end_page_id,
                           debug_mode=debug_mode,
                           lang=lang,
                           )
@@ -10,6 +10,7 @@ def parse_pdf_by_txt(
    start_page_id=0,
    end_page_id=None,
    debug_mode=False,
    lang=None,
):
    dataset = PymuDocDataset(pdf_bytes)
    return pdf_parse_union(dataset,
@@ -19,4 +20,5 @@ def parse_pdf_by_txt(
                           start_page_id=start_page_id,
                           end_page_id=end_page_id,
                           debug_mode=debug_mode,
                           lang=lang,
                           )
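
# Both wrappers above gain an optional `lang` keyword and simply forward it to
# pdf_parse_union. A minimal sketch of that threading pattern; `core_parse` and
# `parse_wrapper` are stand-ins for illustration, not real APIs:
def core_parse(data, lang=None):
    return {'data': data, 'lang': lang}


def parse_wrapper(data, lang=None):
    # The wrapper neither inspects nor re-defaults `lang`; it only forwards
    # it, so the core stays the single source of truth for language handling.
    return core_parse(data, lang=lang)


assert parse_wrapper(b'%PDF-1.7', lang='en') == {'data': b'%PDF-1.7', 'lang': 'en'}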
@@ -2,38 +2,47 @@ import time

from loguru import logger

from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.ocr_content_type import ContentType
from magic_pdf.layout.layout_sort import (LAYOUT_UNPROC, get_bboxes_layout,
                                          get_columns_cnt_of_layout)
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.para.para_split_v2 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import \
    ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.equations_replace import (
    combine_chars_to_pymudict, remove_chars_in_text_blocks,
    replace_equations_in_textblock)
from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
    ocr_prepare_bboxes_for_layout_split
from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
                                               fix_block_spans,
                                               fix_discarded_block,
                                               sort_blocks_by_layout)
from magic_pdf.pre_proc.ocr_span_list_modify import (
    get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
    remove_overlaps_min_spans)
from magic_pdf.pre_proc.resolve_bbox_conflict import \
    check_useful_block_horizontal_overlap


def remove_horizontal_overlap_block_which_smaller(all_bboxes):
    useful_blocks = []
    for bbox in all_bboxes:
        useful_blocks.append({'bbox': bbox[:4]})
    is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
        check_useful_block_horizontal_overlap(useful_blocks)
    )
    if is_useful_block_horz_overlap:
        logger.warning(
            f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
        )
        for bbox in all_bboxes.copy():
            if smaller_bbox == bbox[:4]:
                all_bboxes.remove(bbox)
@@ -41,9 +50,9 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
    return is_useful_block_horz_overlap, all_bboxes
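
# check_useful_block_horizontal_overlap is imported above rather than shown in
# this diff. A hedged sketch of what a horizontal-overlap test between two
# bboxes can look like; the real implementation in resolve_bbox_conflict may
# apply different rules:
def boxes_overlap_horizontally(a, b):
    ax0, ay0, ax1, ay1 = a
    bx0, by0, bx1, by1 = b
    # Two boxes "overlap horizontally" when they share vertical extent (they
    # sit in the same text band) and their x-ranges intersect.
    share_band = min(ay1, by1) > max(ay0, by0)
    x_intersect = min(ax1, bx1) > max(ax0, bx0)
    return share_band and x_intersect


assert boxes_overlap_horizontally((0, 0, 10, 10), (5, 2, 15, 8))
assert not boxes_overlap_horizontally((0, 0, 10, 10), (20, 0, 30, 10))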


def __replace_STX_ETX(text_str: str):
    """Replace \u0002 and \u0003, as these characters become garbled when
    extracted using pymupdf. In fact, they were originally quotation marks.

    Drawback: This issue is only observed in English text; it has not been
    found in Chinese text so far.

    Args:
        text_str (str): raw text
@@ -53,15 +62,15 @@ Drawback: This issue is only observed in English text; it has not been found in
    """
    if text_str:
        s = text_str.replace('\u0002', "'")
        s = s.replace('\u0003', "'")
        return s
    return text_str
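
# Usage illustration for __replace_STX_ETX: the STX/ETX control characters
# that pymupdf garbles are mapped back to apostrophes; falsy input passes
# through unchanged.
assert __replace_STX_ETX('\u0002tis') == "'tis"
assert __replace_STX_ETX('don\u0003t') == "don't"
assert __replace_STX_ETX('') == ''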


def txt_spans_extract(pdf_page, inline_equations, interline_equations):
    text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
    char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
        'blocks'
    ]
    text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
    text_blocks = replace_equations_in_textblock(
@@ -71,189 +80,254 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
    text_blocks = remove_chars_in_text_blocks(text_blocks)
    spans = []
    for v in text_blocks:
        for line in v['lines']:
            for span in line['spans']:
                bbox = span['bbox']
                if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
                    continue
                if span.get('type') not in (
                    ContentType.InlineEquation,
                    ContentType.InterlineEquation,
                ):
                    spans.append(
                        {
                            'bbox': list(span['bbox']),
                            'content': __replace_STX_ETX(span['text']),
                            'type': ContentType.Text,
                            'score': 1.0,
                        }
                    )
    return spans


def replace_text_span(pymu_spans, ocr_spans):
    return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
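
# Usage illustration for replace_text_span: OCR text spans are dropped in
# favour of the pymupdf spans, while non-text OCR spans are kept. The
# ContentType.Image member is assumed here for the example.
ocr_spans = [
    {'type': ContentType.Text, 'content': 'ocr text'},
    {'type': ContentType.Image, 'content': 'figure'},
]
pymu_spans = [{'type': ContentType.Text, 'content': 'pymu text'}]
merged = replace_text_span(pymu_spans, ocr_spans)
# merged keeps the image span and the pymu text span, in that order.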


def parse_page_core(
    pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
):
    need_drop = False
    drop_reason = []

    """Pull the block info we will need from the magic_model object."""
    img_blocks = magic_model.get_imgs(page_id)
    table_blocks = magic_model.get_tables(page_id)
    discarded_blocks = magic_model.get_discarded(page_id)
    text_blocks = magic_model.get_text_blocks(page_id)
    title_blocks = magic_model.get_title_blocks(page_id)
    inline_equations, interline_equations, interline_equation_blocks = (
        magic_model.get_equations(page_id)
    )

    page_w, page_h = magic_model.get_page_size(page_id)

    spans = magic_model.get_all_spans(page_id)

    """Build the spans according to parse_mode."""
    if parse_mode == 'txt':
        """Replace the text spans from OCR with pymu spans!"""
        pymu_spans = txt_spans_extract(
            pdf_docs[page_id], inline_equations, interline_equations
        )
        spans = replace_text_span(pymu_spans, spans)
    elif parse_mode == 'ocr':
        pass
    else:
        raise Exception('parse_mode must be txt or ocr')

    """Among overlapping spans, drop the ones with lower confidence."""
    spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
    """Among overlapping spans, drop the smaller ones."""
    spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)

    """Crop screenshots of images and tables."""
    spans = ocr_cut_image_and_table(
        spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter
    )

    """Gather the bboxes of all blocks."""
    # interline_equation_blocks is not accurate enough; switch to interline_equations below
    interline_equation_blocks = []
    if len(interline_equation_blocks) > 0:
        all_bboxes, all_discarded_blocks, drop_reasons = (
            ocr_prepare_bboxes_for_layout_split(
                img_blocks,
                table_blocks,
                discarded_blocks,
                text_blocks,
                title_blocks,
                interline_equation_blocks,
                page_w,
                page_h,
            )
        )
    else:
        all_bboxes, all_discarded_blocks, drop_reasons = (
            ocr_prepare_bboxes_for_layout_split(
                img_blocks,
                table_blocks,
                discarded_blocks,
                text_blocks,
                title_blocks,
                interline_equations,
                page_w,
                page_h,
            )
        )

    if len(drop_reasons) > 0:
        need_drop = True
        drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)

    """Handle the discarded_blocks first; they need no layout."""
    discarded_block_with_spans, spans = fill_spans_in_blocks(
        all_discarded_blocks, spans, 0.4
    )
    fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)

    """Skip this page if it has no bbox."""
    if len(all_bboxes) == 0:
        logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
        return ocr_construct_page_component_v2(
            [],
            [],
            page_id,
            page_w,
            page_h,
            [],
            [],
            [],
            interline_equations,
            fix_discarded_blocks,
            need_drop,
            drop_reason,
        )

    """Before splitting, check whether any bboxes overlap horizontally. If so, assume this pdf cannot be handled properly for now; such horizontal overlap is most likely caused by interline equations or tables that were not recognized correctly."""
    while True:  # Loop: while horizontal overlap exists, drop the smaller bbox, until none remains
        is_useful_block_horz_overlap, all_bboxes = (
            remove_horizontal_overlap_block_which_smaller(all_bboxes)
        )
        if is_useful_block_horz_overlap:
            need_drop = True
            drop_reason.append(DropReason.USEFUL_BLOCK_HOR_OVERLAP)
        else:
            break

    """Compute the layout from the block info."""
    page_boundry = [0, 0, page_w, page_h]
    layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)

    if len(text_blocks) > 0 and len(all_bboxes) > 0 and len(layout_bboxes) == 0:
        logger.warning(
            f'skip this page, page_id: {page_id}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}'
        )
        need_drop = True
        drop_reason.append(DropReason.CAN_NOT_DETECT_PAGE_LAYOUT)

    """Drop complicated layouts and layouts with more than two columns."""
    if any(
        [lay['layout_label'] == LAYOUT_UNPROC for lay in layout_bboxes]
    ):  # complicated layout
        logger.warning(
            f'skip this page, page_id: {page_id}, reason: {DropReason.COMPLICATED_LAYOUT}'
        )
        need_drop = True
        drop_reason.append(DropReason.COMPLICATED_LAYOUT)

    layout_column_width = get_columns_cnt_of_layout(layout_tree)
    if layout_column_width > 2:  # drop pdfs whose layout has more than two columns
        logger.warning(
            f'skip this page, page_id: {page_id}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}'
        )
        need_drop = True
        drop_reason.append(DropReason.TOO_MANY_LAYOUT_COLUMNS)

    """Sort all blocks kept on this page by layout order."""
    sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)

    """Fill the spans into the sorted blocks."""
    block_with_spans, spans = fill_spans_in_blocks(sorted_blocks, spans, 0.3)

    """Fix the blocks."""
    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)

    """Get the lists QA needs to externalize."""
    images, tables, interline_equations = get_qa_need_list_v2(fix_blocks)

    """Build pdf_info_dict."""
    page_info = ocr_construct_page_component_v2(
        fix_blocks,
        layout_bboxes,
        page_id,
        page_w,
        page_h,
        layout_tree,
        images,
        tables,
        interline_equations,
        fix_discarded_blocks,
        need_drop,
        drop_reason,
    )
    return page_info


def pdf_parse_union(
    pdf_bytes,
    model_list,
    imageWriter,
    parse_mode,
    start_page_id=0,
    end_page_id=None,
    debug_mode=False,
):
    pdf_bytes_md5 = compute_md5(pdf_bytes)
    pdf_docs = fitz.open('pdf', pdf_bytes)

    """Initialize an empty pdf_info_dict."""
    pdf_info_dict = {}

    """Initialize magic_model from model_list and the docs object."""
    magic_model = MagicModel(model_list, pdf_docs)

    """Parse the pdf over the requested page range."""
    # end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
    end_page_id = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else len(pdf_docs) - 1
    )
    if end_page_id > len(pdf_docs) - 1:
        logger.warning('end_page_id is out of range, use pdf_docs length')
        end_page_id = len(pdf_docs) - 1

    """Record the start time."""
    start_time = time.time()

    for page_id, page in enumerate(pdf_docs):
        """In debug mode, log how long each page took to parse."""
        if debug_mode:
            time_now = time.time()
            logger.info(
                f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
            )
            start_time = time_now

        """Parse each page of the pdf."""
        if start_page_id <= page_id <= end_page_id:
            page_info = parse_page_core(
                pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
            )
        else:
            page_w = page.rect.width
            page_h = page.rect.height
            page_info = ocr_construct_page_component_v2(
                [], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
            )
        pdf_info_dict[f'page_{page_id}'] = page_info

    """Split paragraphs."""
    para_split(pdf_info_dict, debug_mode=debug_mode)
@@ -261,7 +335,7 @@ def pdf_parse_union(pdf_bytes,
    """Convert the dict to a list."""
    pdf_info_list = dict_to_list(pdf_info_dict)
    new_pdf_info_dict = {
        'pdf_info': pdf_info_list,
    }

    return new_pdf_info_dict
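
# The page-range clamp in pdf_parse_union, restated in isolation: a missing or
# negative end_page_id falls back to the last page, and out-of-range values
# are clamped (the warning log is omitted in this condensed sketch).
def normalize_end_page_id(end_page_id, page_count):
    end = (
        end_page_id
        if end_page_id is not None and end_page_id >= 0
        else page_count - 1
    )
    return min(end, page_count - 1)


assert normalize_end_page_id(None, 10) == 9
assert normalize_end_page_id(-1, 10) == 9
assert normalize_end_page_id(99, 10) == 9
assert normalize_end_page_id(3, 10) == 3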
from abc import ABC, abstractmethod

from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
from magic_pdf.libs.json_compressor import JsonCompressor


class AbsPipe(ABC):
    """Abstract base class for txt and ocr processing."""

    PIP_OCR = 'ocr'
    PIP_TXT = 'txt'

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                 start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
        self.pdf_bytes = pdf_bytes
        self.model_list = model_list
@@ -35,23 +33,17 @@ class AbsPipe(ABC):

    @abstractmethod
    def pipe_classify(self):
        """Stateful classification."""
        raise NotImplementedError

    @abstractmethod
    def pipe_analyze(self):
        """Stateful model analysis."""
        raise NotImplementedError

    @abstractmethod
    def pipe_parse(self):
        """Stateful parsing."""
        raise NotImplementedError

    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
@@ -64,27 +56,25 @@ class AbsPipe(ABC):

    @staticmethod
    def classify(pdf_bytes: bytes) -> str:
        """Decide from the pdf's metadata whether it is a text pdf or an ocr pdf."""
        pdf_meta = pdf_meta_scan(pdf_bytes)
        if pdf_meta.get('_need_drop', False):  # if a drop flag is returned, raise an exception
            raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
        else:
            is_encrypted = pdf_meta['is_encrypted']
            is_needs_password = pdf_meta['is_needs_password']
            if is_encrypted or is_needs_password:  # encrypted pdfs, password-protected pdfs, and pdfs without pages are not processed
                raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
            else:
                is_text_pdf, results = classify(
                    pdf_meta['total_page'],
                    pdf_meta['page_width_pts'],
                    pdf_meta['page_height_pts'],
                    pdf_meta['image_info_per_page'],
                    pdf_meta['text_len_per_page'],
                    pdf_meta['imgs_per_page'],
                    pdf_meta['text_layout_per_page'],
                    pdf_meta['invalid_chars'],
                )
                if is_text_pdf:
                    return AbsPipe.PIP_TXT
@@ -93,22 +83,16 @@ class AbsPipe(ABC):

    @staticmethod
    def mk_uni_format(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF) -> list:
        """Generate the unified-format content_list according to the pdf type."""
        pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
        pdf_info_list = pdf_mid_data['pdf_info']
        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
        return content_list

    @staticmethod
    def mk_markdown(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD) -> list:
        """Generate markdown according to the pdf type."""
        pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
        pdf_info_list = pdf_mid_data['pdf_info']
        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
        return md_content
from loguru import logger

from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf


class OCRPipe(AbsPipe):

    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                 start_page_id=0, end_page_id=None, lang=None,
                 layout_model=None, formula_enable=None, table_enable=None):
        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
@@ -32,10 +32,10 @@ class OCRPipe(AbsPipe):

    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
        result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
        logger.info('ocr_pipe mk content list finished')
        return result

    def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
        result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
        logger.info(f'ocr_pipe mk {md_make_mode} finished')
        return result
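
# Hedged usage sketch for OCRPipe. The constructor arguments follow the
# signature above; FileBasedDataWriter is assumed to be the local-disk writer
# in magic_pdf.data.data_reader_writer, and the analyze/parse bodies are not
# shown in this diff.
from magic_pdf.data.data_reader_writer import FileBasedDataWriter

with open('demo.pdf', 'rb') as f:
    pdf_bytes = f.read()
image_writer = FileBasedDataWriter('/tmp/images')
pipe = OCRPipe(pdf_bytes, model_list=[], image_writer=image_writer)
pipe.pipe_analyze()    # run the models to fill model_list
pipe.pipe_parse()      # build the middle data
md = pipe.pipe_mk_markdown('/tmp/images')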