Unverified Commit 158e556b authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1063 from opendatalab/release-0.10.0

Release 0.10.0
parents 038f48d3 30be5017
COLOR_BG_HEADER_TXT_BLOCK = "color_background_header_txt_block"
PAGE_NO = "page-no" # 页码
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area' # 页眉页脚内的文本
VERTICAL_TEXT = 'vertical-text' # 垂直文本
ROTATE_TEXT = 'rotate-text' # 旋转文本
EMPTY_SIDE_BLOCK = 'empty-side-block' # 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT = 'on-image-text' # 文本在图片上
ON_TABLE_TEXT = 'on-table-text' # 文本在表格上
class DropTag:
PAGE_NUMBER = "page_no"
HEADER = "header"
FOOTER = "footer"
FOOTNOTE = "footnote"
NOT_IN_LAYOUT = "not_in_layout"
SPAN_OVERLAP = "span_overlap"
BLOCK_OVERLAP = "block_overlap"
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.commons import fitz
from magic_pdf.libs.commons import join_path
from io import BytesIO
import cv2
import numpy as np
from PIL import Image
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.hash_utils import compute_sha256
def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: AbsReaderWriter):
"""
从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
"""
def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
"""从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
图片存放在save_path下,文件名是:
{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
# 拼接文件名
filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
# 老版本返回不带bucket的路径
img_path = join_path(return_path, filename) if return_path is not None else None
# 新版本生成平铺路径
img_hash256_path = f"{compute_sha256(img_path)}.jpg"
img_hash256_path = f'{compute_sha256(img_path)}.jpg'
# 将坐标转换为fitz.Rect对象
rect = fitz.Rect(*bbox)
......@@ -28,6 +29,29 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
imageWriter.write(byte_data, img_hash256_path, AbsReaderWriter.MODE_BIN)
imageWriter.write(img_hash256_path, byte_data)
return img_hash256_path
def cut_image_to_pil_image(bbox: tuple, page: fitz.Page, mode="pillow"):
# 将坐标转换为fitz.Rect对象
rect = fitz.Rect(*bbox)
# 配置缩放倍数为3倍
zoom = fitz.Matrix(3, 3)
# 截取图片
pix = page.get_pixmap(clip=rect, matrix=zoom)
# 将字节数据转换为文件对象
image_file = BytesIO(pix.tobytes(output='png'))
# 使用 Pillow 打开图像
pil_image = Image.open(image_file)
if mode == "cv2":
image_result = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)
elif mode == "pillow":
image_result = pil_image
else:
raise ValueError(f"mode: {mode} is not supported.")
return image_result
\ No newline at end of file
......@@ -163,7 +163,9 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
page_width = img_dict["width"]
page_height = img_dict["height"]
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
page_info = {"page_no": index, "height": page_height, "width": page_width}
......
import enum
import json
from magic_pdf.config.model_block_type import ModelBlockTypeEnum
from magic_pdf.config.ocr_content_type import CategoryId, ContentType
from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
FileBasedDataWriter)
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
bbox_relative_pos, box_area, calculate_iou,
......@@ -9,11 +13,7 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from magic_pdf.libs.commons import fitz, join_path
from magic_pdf.libs.coordinate_transform import get_scale_ratio
from magic_pdf.libs.local_math import float_gt
from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO = 0.6
MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
......@@ -1050,27 +1050,27 @@ class MagicModel:
if __name__ == '__main__':
drw = DiskReaderWriter(r'D:/project/20231108code-clean')
drw = FileBasedDataReader(r'D:/project/20231108code-clean')
if 0:
pdf_file_path = r'linshixuqiu\19983-00.pdf'
model_file_path = r'linshixuqiu\19983-00_new.json'
pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
pdf_bytes = drw.read(pdf_file_path)
model_json_txt = drw.read(model_file_path).decode()
model_list = json.loads(model_json_txt)
write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path = 'imgs'
img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
img_writer = FileBasedDataWriter(join_path(write_path, img_bucket_path))
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
if 1:
from magic_pdf.data.dataset import PymuDocDataset
model_list = json.loads(
drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
)
pdf_bytes = drw.read(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
)
pdf_docs = fitz.open('pdf', pdf_bytes)
magic_model = MagicModel(model_list, pdf_docs)
pdf_bytes = drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf')
magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
for i in range(7):
print(magic_model.get_imgs(i))
import numpy as np
import torch
from loguru import logger
# flake8: noqa
import os
import time
import cv2
import numpy as np
import torch
import yaml
from loguru import logger
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
......@@ -13,16 +15,18 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import torchtext
if torchtext.__version__ >= "0.18.0":
if torchtext.__version__ >= '0.18.0':
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
from magic_pdf.libs.Constants import *
from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
class CustomPEKModel:
......@@ -41,42 +45,54 @@ class CustomPEKModel:
model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
# 构建 model_configs.yaml 文件的完整路径
config_path = os.path.join(model_config_dir, 'model_configs.yaml')
with open(config_path, "r", encoding='utf-8') as f:
with open(config_path, 'r', encoding='utf-8') as f:
self.configs = yaml.load(f, Loader=yaml.FullLoader)
# 初始化解析配置
# layout config
self.layout_config = kwargs.get("layout_config")
self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)
self.layout_config = kwargs.get('layout_config')
self.layout_model_name = self.layout_config.get(
'model', MODEL_NAME.DocLayout_YOLO
)
# formula config
self.formula_config = kwargs.get("formula_config")
self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
self.apply_formula = self.formula_config.get("enable", True)
self.formula_config = kwargs.get('formula_config')
self.mfd_model_name = self.formula_config.get(
'mfd_model', MODEL_NAME.YOLO_V8_MFD
)
self.mfr_model_name = self.formula_config.get(
'mfr_model', MODEL_NAME.UniMerNet_v2_Small
)
self.apply_formula = self.formula_config.get('enable', True)
# table config
self.table_config = kwargs.get("table_config")
self.apply_table = self.table_config.get("enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
self.table_config = kwargs.get('table_config')
self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
# ocr config
self.apply_ocr = ocr
self.lang = kwargs.get("lang", None)
self.lang = kwargs.get('lang', None)
logger.info(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}".format(
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
self.lang
'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
'apply_table: {}, table_model: {}, lang: {}'.format(
self.layout_model_name,
self.apply_formula,
self.apply_ocr,
self.apply_table,
self.table_model_name,
self.lang,
)
)
# 初始化解析方案
self.device = kwargs.get("device", "cpu")
logger.info("using device: {}".format(self.device))
models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
logger.info("using models_dir: {}".format(models_dir))
self.device = kwargs.get('device', 'cpu')
logger.info('using device: {}'.format(self.device))
models_dir = kwargs.get(
'models_dir', os.path.join(root_dir, 'resources', 'models')
)
logger.info('using models_dir: {}'.format(models_dir))
atom_model_manager = AtomModelSingleton()
......@@ -85,18 +101,24 @@ class CustomPEKModel:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
device=self.device
mfd_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.mfd_model_name]
)
),
device=self.device,
)
# 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
mfr_weight_dir = str(
os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
)
mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
device=self.device
device=self.device,
)
# 初始化layout模型
......@@ -104,42 +126,51 @@ class CustomPEKModel:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.LAYOUTLMv3,
layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
device=self.device
layout_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.layout_model_name]
)
),
layout_config_file=str(
os.path.join(
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
)
),
device=self.device,
)
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
device=self.device
doclayout_yolo_weights=str(
os.path.join(
models_dir, self.configs['weights'][self.layout_model_name]
)
),
device=self.device,
)
# 初始化ocr
if self.apply_ocr:
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_name]
table_model_dir = self.configs['weights'][self.table_model_name]
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
device=self.device,
)
logger.info('DocAnalysis init done!')
def __call__(self, image):
page_start = time.time()
# layout检测
layout_start = time.time()
layout_res = []
......@@ -150,7 +181,7 @@ class CustomPEKModel:
# doclayout_yolo
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection time: {layout_cost}")
logger.info(f'layout detection time: {layout_cost}')
pil_img = Image.fromarray(image)
......@@ -158,40 +189,47 @@ class CustomPEKModel:
# 公式检测
mfd_start = time.time()
mfd_res = self.mfd_model.predict(image)
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
# 公式识别
mfr_start = time.time()
formula_list = self.mfr_model.predict(mfd_res, image)
layout_res.extend(formula_list)
mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
# 清理显存
clean_vram(self.device, vram_threshold=8)
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
get_res_list_from_layout_res(layout_res)
)
# ocr识别
if self.apply_ocr:
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_start = time.time()
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
else:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list)
# Integration results
if ocr_res:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list)
ocr_cost = round(time.time() - ocr_start, 2)
ocr_cost = round(time.time() - ocr_start, 2)
if self.apply_ocr:
logger.info(f"ocr time: {ocr_cost}")
else:
logger.info(f"det time: {ocr_cost}")
# 表格识别 table recognition
if self.apply_table:
......@@ -202,27 +240,35 @@ class CustomPEKModel:
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
table_result = self.table_model.predict(new_image, "html")
table_result = self.table_model.predict(new_image, 'html')
if len(table_result) > 0:
html_code = table_result[0]
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
html_code, table_cell_bboxes, elapse = self.table_model.predict(
new_image
)
run_time = time.time() - single_table_start_time
if run_time > self.table_max_time:
logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
logger.warning(
f'table recognition processing exceeds max time {self.table_max_time}s'
)
# 判断是否返回正常
if html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
expected_ending = html_code.strip().endswith(
'</html>'
) or html_code.strip().endswith('</table>')
if expected_ending:
res["html"] = html_code
res['html'] = html_code
else:
logger.warning(f"table recognition processing fails, not found expected HTML table end")
logger.warning(
'table recognition processing fails, not found expected HTML table end'
)
else:
logger.warning(f"table recognition processing fails, not get html return")
logger.info(f"table time: {round(time.time() - table_start, 2)}")
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
logger.warning(
'table recognition processing fails, not get html return'
)
logger.info(f'table time: {round(time.time() - table_start, 2)}')
return layout_res
from loguru import logger
from magic_pdf.libs.Constants import MODEL_NAME
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
DocLayoutYOLOModel
from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
Layoutlmv3_Predictor
from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import \
ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import \
RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import TableMasterPaddleModel
from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import \
StructTableModel
from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
......@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
config = {
"model_dir": model_path,
"device": _device_
'model_dir': model_path,
'device': _device_
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel()
else:
logger.error("table model type not allow")
logger.error('table model type not allow')
exit(1)
return table_model
......@@ -58,7 +63,7 @@ def ocr_model_init(show_log: bool = False,
use_dilation=True,
det_db_unclip_ratio=1.8,
):
if lang is not None:
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
show_log=show_log,
det_db_box_thresh=det_db_box_thresh,
......@@ -87,8 +92,8 @@ class AtomModelSingleton:
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get("lang", None)
layout_model_name = kwargs.get("layout_model_name", None)
lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang)
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
......@@ -98,47 +103,47 @@ class AtomModelSingleton:
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
if kwargs.get('layout_model_name') == MODEL_NAME.LAYOUTLMv3:
atom_model = layout_model_init(
kwargs.get("layout_weights"),
kwargs.get("layout_config_file"),
kwargs.get("device")
kwargs.get('layout_weights'),
kwargs.get('layout_config_file'),
kwargs.get('device')
)
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
elif kwargs.get('layout_model_name') == MODEL_NAME.DocLayout_YOLO:
atom_model = doclayout_yolo_model_init(
kwargs.get("doclayout_yolo_weights"),
kwargs.get("device")
kwargs.get('doclayout_yolo_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get("mfd_weights"),
kwargs.get("device")
kwargs.get('mfd_weights'),
kwargs.get('device')
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get("mfr_weight_dir"),
kwargs.get("mfr_cfg_path"),
kwargs.get("device")
kwargs.get('mfr_weight_dir'),
kwargs.get('mfr_cfg_path'),
kwargs.get('device')
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh"),
kwargs.get("lang")
kwargs.get('ocr_show_log'),
kwargs.get('det_db_box_thresh'),
kwargs.get('lang')
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get("table_model_name"),
kwargs.get("table_model_path"),
kwargs.get("table_max_time"),
kwargs.get("device")
kwargs.get('table_model_name'),
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device')
)
else:
logger.error("model name not allow")
logger.error('model name not allow')
exit(1)
if atom_model is None:
logger.error("model init failed")
logger.error('model init failed')
exit(1)
else:
return atom_model
......@@ -71,7 +71,13 @@ def remove_intervals(original, masks):
def update_det_boxes(dt_boxes, mfd_res):
new_dt_boxes = []
angle_boxes_list = []
for text_box in dt_boxes:
if calculate_is_angle(text_box):
angle_boxes_list.append(text_box)
continue
text_bbox = points_to_bbox(text_box)
masks_list = []
for mf_box in mfd_res:
......@@ -85,6 +91,9 @@ def update_det_boxes(dt_boxes, mfd_res):
temp_dt_box.append(bbox_to_points([text_remove_mask[0], text_bbox[1], text_remove_mask[1], text_bbox[3]]))
if len(temp_dt_box) > 0:
new_dt_boxes.extend(temp_dt_box)
new_dt_boxes.extend(angle_boxes_list)
return new_dt_boxes
......@@ -143,9 +152,11 @@ def merge_det_boxes(dt_boxes):
angle_boxes_list = []
for text_box in dt_boxes:
text_bbox = points_to_bbox(text_box)
if text_bbox[2] <= text_bbox[0] or text_bbox[3] <= text_bbox[1]:
if calculate_is_angle(text_box):
angle_boxes_list.append(text_box)
continue
text_box_dict = {
'bbox': text_bbox,
'type': 'text',
......@@ -200,15 +211,21 @@ def get_ocr_result_list(ocr_res, useful_list):
ocr_result_list = []
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
if average_angle_degrees > 0.5:
if len(box_ocr_res) == 2:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
else:
p1, p2, p3, p4 = box_ocr_res
text, score = "", 1
# average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
# if average_angle_degrees > 0.5:
poly = [p1, p2, p3, p4]
if calculate_is_angle(poly):
# logger.info(f"average_angle_degrees: {average_angle_degrees}, text: {text}")
# 与x轴的夹角超过0.5度,对边界做一下矫正
# 计算几何中心
x_center = sum(point[0] for point in box_ocr_res[0]) / 4
y_center = sum(point[1] for point in box_ocr_res[0]) / 4
x_center = sum(point[0] for point in poly) / 4
y_center = sum(point[1] for point in poly) / 4
new_height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
new_width = p3[0] - p1[0]
p1 = [x_center - new_width / 2, y_center - new_height / 2]
......@@ -257,3 +274,12 @@ def calculate_angle_degrees(poly):
# logger.info(f"average_angle_degrees: {average_angle_degrees}")
return average_angle_degrees
def calculate_is_angle(poly):
p1, p2, p3, p4 = poly
height = ((p4[1] - p1[1]) + (p3[1] - p2[1])) / 2
if 0.8 * height <= (p3[1] - p1[1]) <= 1.2 * height:
return False
else:
# logger.info((p3[1] - p1[1])/height)
return True
\ No newline at end of file
......@@ -78,9 +78,18 @@ class ModifiedPaddleOCR(PaddleOCR):
for idx, img in enumerate(imgs):
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if not dt_boxes:
if dt_boxes is None:
ocr_res.append(None)
continue
dt_boxes = sorted_boxes(dt_boxes)
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
bef = time.time()
dt_boxes = update_det_boxes(dt_boxes, mfd_res)
aft = time.time()
logger.debug("split text box by formula, new dt_boxes num : {}, elapsed : {}".format(
len(dt_boxes), aft - bef))
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
......@@ -125,9 +134,8 @@ class ModifiedPaddleOCR(PaddleOCR):
dt_boxes = sorted_boxes(dt_boxes)
# @todo 目前是在bbox层merge,对倾斜文本行的兼容性不佳,需要修改成支持poly的merge
# dt_boxes = merge_det_boxes(dt_boxes)
# merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
dt_boxes = merge_det_boxes(dt_boxes)
if mfd_res:
bef = time.time()
......
......@@ -10,5 +10,7 @@ class RapidTableModel(object):
def predict(self, image):
ocr_result, _ = self.ocr_engine(np.asarray(image))
if ocr_result is None:
return None, None, None
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
\ No newline at end of file
import os
import cv2
import numpy as np
from paddleocr.ppstructure.table.predict_table import TableSystem
from paddleocr.ppstructure.utility import init_args
from magic_pdf.libs.Constants import *
import os
from PIL import Image
import numpy as np
from magic_pdf.config.constants import * # noqa: F403
class TableMasterPaddleModel(object):
"""
This class is responsible for converting image of table into HTML format using a pre-trained model.
"""This class is responsible for converting image of table into HTML format
using a pre-trained model.
Attributes:
- table_sys: An instance of TableSystem initialized with parsed arguments.
Attributes:
- table_sys: An instance of TableSystem initialized with parsed arguments.
Methods:
- __init__(config): Initializes the model with configuration parameters.
- img2html(image): Converts a PIL Image or NumPy array to HTML string.
- parse_args(**kwargs): Parses configuration arguments.
Methods:
- __init__(config): Initializes the model with configuration parameters.
- img2html(image): Converts a PIL Image or NumPy array to HTML string.
- parse_args(**kwargs): Parses configuration arguments.
"""
def __init__(self, config):
......@@ -40,30 +42,30 @@ class TableMasterPaddleModel(object):
image = np.asarray(image)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
pred_res, _ = self.table_sys(image)
pred_html = pred_res["html"]
pred_html = pred_res['html']
# res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
# "</table></body></html>","") + "</table></td>\n"
return pred_html
def parse_args(self, **kwargs):
parser = init_args()
model_dir = kwargs.get("model_dir")
table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR)
table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT)
det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR)
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR)
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT)
device = kwargs.get("device", "cpu")
use_gpu = True if device.startswith("cuda") else False
model_dir = kwargs.get('model_dir')
table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR) # noqa: F405
table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT) # noqa: F405
det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR) # noqa: F405
rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR) # noqa: F405
rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT) # noqa: F405
device = kwargs.get('device', 'cpu')
use_gpu = True if device.startswith('cuda') else False
config = {
"use_gpu": use_gpu,
"table_max_len": kwargs.get("table_max_len", TABLE_MAX_LEN),
"table_algorithm": "TableMaster",
"table_model_dir": table_model_dir,
"table_char_dict_path": table_char_dict_path,
"det_model_dir": det_model_dir,
"rec_model_dir": rec_model_dir,
"rec_char_dict_path": rec_char_dict_path,
'use_gpu': use_gpu,
'table_max_len': kwargs.get('table_max_len', TABLE_MAX_LEN), # noqa: F405
'table_algorithm': 'TableMaster',
'table_model_dir': table_model_dir,
'table_char_dict_path': table_char_dict_path,
'det_model_dir': det_model_dir,
'rec_model_dir': rec_model_dir,
'rec_char_dict_path': rec_char_dict_path,
}
parser.set_defaults(**config)
return parser.parse_args([])
import os
import json
from magic_pdf.para.commons import *
from magic_pdf.para.raw_processor import RawBlockProcessor
from magic_pdf.para.layout_match_processor import LayoutFilterProcessor
from magic_pdf.para.stats import BlockStatisticsCalculator
from magic_pdf.para.stats import DocStatisticsCalculator
from magic_pdf.para.title_processor import TitleProcessor
from magic_pdf.para.block_termination_processor import BlockTerminationProcessor
from magic_pdf.para.block_continuation_processor import BlockContinuationProcessor
from magic_pdf.para.draw import DrawAnnos
from magic_pdf.para.exceptions import (
DenseSingleLineBlockException,
TitleDetectionException,
TitleLevelException,
ParaSplitException,
ParaMergeException,
DiscardByException,
)
if sys.version_info[0] >= 3:
sys.stdout.reconfigure(encoding="utf-8") # type: ignore
class ParaProcessPipeline:
def __init__(self) -> None:
pass
def para_process_pipeline(self, pdf_info_dict, para_debug_mode=None, input_pdf_path=None, output_pdf_path=None):
"""
This function processes the paragraphs, including:
1. Read raw input json file into pdf_dic
2. Detect and replace equations
3. Combine spans into a natural line
4. Check if the paragraphs are inside bboxes passed from "layout_bboxes" key
5. Compute statistics for each block
6. Detect titles in the document
7. Detect paragraphs inside each block
8. Divide the level of the titles
9. Detect and combine paragraphs from different blocks into one paragraph
10. Check whether the final results after checking headings, dividing paragraphs within blocks, and merging paragraphs between blocks are plausible and reasonable.
11. Draw annotations on the pdf file
Parameters
----------
pdf_dic_json_fpath : str
path to the pdf dictionary json file.
Notice: data noises, including overlap blocks, header, footer, watermark, vertical margin note have been removed already.
input_pdf_doc : str
path to the input pdf file
output_pdf_path : str
path to the output pdf file
Returns
-------
pdf_dict : dict
result dictionary
"""
error_info = None
output_json_file = ""
output_dir = ""
if input_pdf_path is not None:
input_pdf_path = os.path.abspath(input_pdf_path)
# print_green_on_red(f">>>>>>>>>>>>>>>>>>> Process the paragraphs of {input_pdf_path}")
if output_pdf_path is not None:
output_dir = os.path.dirname(output_pdf_path)
output_json_file = f"{output_dir}/pdf_dic.json"
def __save_pdf_dic(pdf_dic, output_pdf_path, stage="0", para_debug_mode=para_debug_mode):
"""
Save the pdf_dic to a json file
"""
output_pdf_file_name = os.path.basename(output_pdf_path)
# output_dir = os.path.dirname(output_pdf_path)
output_dir = "\\tmp\\pdf_parse"
output_pdf_file_name = output_pdf_file_name.replace(".pdf", f"_stage_{stage}.json")
pdf_dic_json_fpath = os.path.join(output_dir, output_pdf_file_name)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if para_debug_mode == "full":
with open(pdf_dic_json_fpath, "w", encoding="utf-8") as f:
json.dump(pdf_dic, f, indent=2, ensure_ascii=False)
# Validate the output already exists
if not os.path.exists(pdf_dic_json_fpath):
print_red(f"Failed to save the pdf_dic to {pdf_dic_json_fpath}")
return None
else:
print_green(f"Succeed to save the pdf_dic to {pdf_dic_json_fpath}")
return pdf_dic_json_fpath
"""
Preprocess the lines of block
"""
# Find and replace the interline and inline equations, should be better done before the paragraph processing
# Create "para_blocks" for each page.
# equationProcessor = EquationsProcessor()
# pdf_dic = equationProcessor.batch_process_blocks(pdf_info_dict)
# Combine spans into a natural line
rawBlockProcessor = RawBlockProcessor()
pdf_dic = rawBlockProcessor.batch_process_blocks(pdf_info_dict)
# print(f"pdf_dic['page_0']['para_blocks'][0]: {pdf_dic['page_0']['para_blocks'][0]}", end="\n\n")
# Check if the paragraphs are inside bboxes passed from "layout_bboxes" key
layoutFilter = LayoutFilterProcessor()
pdf_dic = layoutFilter.batch_process_blocks(pdf_dic)
# Compute statistics for each block
blockStatisticsCalculator = BlockStatisticsCalculator()
pdf_dic = blockStatisticsCalculator.batch_process_blocks(pdf_dic)
# print(f"pdf_dic['page_0']['para_blocks'][0]: {pdf_dic['page_0']['para_blocks'][0]}", end="\n\n")
# Compute statistics for all blocks(namely this pdf document)
docStatisticsCalculator = DocStatisticsCalculator()
pdf_dic = docStatisticsCalculator.calc_stats_of_doc(pdf_dic)
# print(f"pdf_dic['statistics']: {pdf_dic['statistics']}", end="\n\n")
# Dump the first three stages of pdf_dic to a json file
if para_debug_mode == "full":
pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="0", para_debug_mode=para_debug_mode)
"""
Detect titles in the document
"""
doc_statistics = pdf_dic["statistics"]
titleProcessor = TitleProcessor(doc_statistics)
pdf_dic = titleProcessor.batch_process_blocks_detect_titles(pdf_dic)
if para_debug_mode == "full":
pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="1", para_debug_mode=para_debug_mode)
"""
Detect and divide the level of the titles
"""
titleProcessor = TitleProcessor()
pdf_dic = titleProcessor.batch_process_blocks_recog_title_level(pdf_dic)
if para_debug_mode == "full":
pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="2", para_debug_mode=para_debug_mode)
"""
Detect and split paragraphs inside each block
"""
blockInnerParasProcessor = BlockTerminationProcessor()
pdf_dic = blockInnerParasProcessor.batch_process_blocks(pdf_dic)
if para_debug_mode == "full":
pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="3", para_debug_mode=para_debug_mode)
# pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="3", para_debug_mode="full")
# print_green(f"pdf_dic_json_fpath: {pdf_dic_json_fpath}")
"""
Detect and combine paragraphs from different blocks into one paragraph
"""
blockContinuationProcessor = BlockContinuationProcessor()
pdf_dic = blockContinuationProcessor.batch_tag_paras(pdf_dic)
pdf_dic = blockContinuationProcessor.batch_merge_paras(pdf_dic)
if para_debug_mode == "full":
pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="4", para_debug_mode=para_debug_mode)
# pdf_dic_json_fpath = __save_pdf_dic(pdf_dic, output_pdf_path, stage="4", para_debug_mode="full")
# print_green(f"pdf_dic_json_fpath: {pdf_dic_json_fpath}")
"""
Discard pdf files by checking exceptions and return the error info to the caller
"""
discardByException = DiscardByException()
is_discard_by_single_line_block = discardByException.discard_by_single_line_block(
pdf_dic, exception=DenseSingleLineBlockException()
)
is_discard_by_title_detection = discardByException.discard_by_title_detection(
pdf_dic, exception=TitleDetectionException()
)
is_discard_by_title_level = discardByException.discard_by_title_level(pdf_dic, exception=TitleLevelException())
is_discard_by_split_para = discardByException.discard_by_split_para(pdf_dic, exception=ParaSplitException())
is_discard_by_merge_para = discardByException.discard_by_merge_para(pdf_dic, exception=ParaMergeException())
"""
if any(
info is not None
for info in [
is_discard_by_single_line_block,
is_discard_by_title_detection,
is_discard_by_title_level,
is_discard_by_split_para,
is_discard_by_merge_para,
]
):
error_info = next(
(
info
for info in [
is_discard_by_single_line_block,
is_discard_by_title_detection,
is_discard_by_title_level,
is_discard_by_split_para,
is_discard_by_merge_para,
]
if info is not None
),
None,
)
return pdf_dic, error_info
if any(
info is not None
for info in [
is_discard_by_single_line_block,
is_discard_by_title_detection,
is_discard_by_title_level,
is_discard_by_split_para,
is_discard_by_merge_para,
]
):
error_info = next(
(
info
for info in [
is_discard_by_single_line_block,
is_discard_by_title_detection,
is_discard_by_title_level,
is_discard_by_split_para,
is_discard_by_merge_para,
]
if info is not None
),
None,
)
return pdf_dic, error_info
"""
"""
Dump the final pdf_dic to a json file
"""
if para_debug_mode is not None:
with open(output_json_file, "w", encoding="utf-8") as f:
json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
"""
Draw the annotations
"""
if is_discard_by_single_line_block is not None:
error_info = is_discard_by_single_line_block
elif is_discard_by_title_detection is not None:
error_info = is_discard_by_title_detection
elif is_discard_by_title_level is not None:
error_info = is_discard_by_title_level
elif is_discard_by_split_para is not None:
error_info = is_discard_by_split_para
elif is_discard_by_merge_para is not None:
error_info = is_discard_by_merge_para
if error_info is not None:
return pdf_dic, error_info
"""
Dump the final pdf_dic to a json file
"""
if para_debug_mode is not None:
with open(output_json_file, "w", encoding="utf-8") as f:
json.dump(pdf_info_dict, f, ensure_ascii=False, indent=4)
"""
Draw the annotations
"""
if para_debug_mode is not None:
drawAnnos = DrawAnnos()
drawAnnos.draw_annos(input_pdf_path, pdf_dic, output_pdf_path)
"""
Remove the intermediate files which are generated in the process of paragraph processing if debug_mode is simple
"""
if para_debug_mode is not None:
for fpath in os.listdir(output_dir):
if fpath.endswith(".json") and "stage" in fpath:
os.remove(os.path.join(output_dir, fpath))
return pdf_dic, error_info
from sklearn.cluster import DBSCAN
import numpy as np
from loguru import logger
from sklearn.cluster import DBSCAN
from magic_pdf.libs.boxbase import _is_in_or_part_overlap_with_area_ratio as is_in_layout
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.config.ocr_content_type import ContentType
from magic_pdf.libs.boxbase import \
_is_in_or_part_overlap_with_area_ratio as is_in_layout
LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?',":", ":", ")", ")", ";"]
LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ':', ':', ')', ')', ';']
INLINE_EQUATION = ContentType.InlineEquation
INTERLINE_EQUATION = ContentType.InterlineEquation
TEXT = ContentType.Text
......@@ -14,30 +14,36 @@ TEXT = ContentType.Text
def __get_span_text(span):
c = span.get('content', '')
if len(c)==0:
if len(c) == 0:
c = span.get('image_path', '')
return c
def __detect_list_lines(lines, new_layout_bboxes, lang):
"""
探测是否包含了列表,并且把列表的行分开.
"""探测是否包含了列表,并且把列表的行分开.
这样的段落特点是,顶格字母大写/数字,紧跟着几行缩进的。缩进的行首字母含小写的。
"""
def find_repeating_patterns(lst):
indices = []
ones_indices = []
i = 0
while i < len(lst) - 1: # 确保余下元素至少有2个
if lst[i] == 1 and lst[i+1] in [2, 3]: # 额外检查以防止连续出现的1
if lst[i] == 1 and lst[i + 1] in [2, 3]: # 额外检查以防止连续出现的1
start = i
ones_in_this_interval = [i]
i += 1
while i < len(lst) and lst[i] in [2, 3]:
i += 1
# 验证下一个序列是否符合条件
if i < len(lst) - 1 and lst[i] == 1 and lst[i+1] in [2, 3] and lst[i-1] in [2, 3]:
if (
i < len(lst) - 1
and lst[i] == 1
and lst[i + 1] in [2, 3]
and lst[i - 1] in [2, 3]
):
while i < len(lst) and lst[i] in [1, 2, 3]:
if lst[i] == 1:
ones_in_this_interval.append(i)
......@@ -49,11 +55,13 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
else:
i += 1
return indices, ones_indices
"""===================="""
def split_indices(slen, index_array):
result = []
last_end = 0
for start, end in sorted(index_array):
if start > last_end:
# 前一个区间结束到下一个区间开始之间的部分标记为"text"
......@@ -67,9 +75,10 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
result.append(('text', last_end, slen - 1))
return result
"""===================="""
if lang!='en':
if lang != 'en':
return lines, None
else:
total_lines = len(lines)
......@@ -81,7 +90,7 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
3. 如果非顶格,首字符大写,编码为2
4. 如果非顶格,首字符非大写编码为3
"""
for l in lines:
for l in lines: # noqa: E741
first_char = __get_span_text(l['spans'][0])[0]
layout_left = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)[0]
if l['bbox'][0] == layout_left:
......@@ -94,68 +103,79 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
line_fea_encode.append(2)
else:
line_fea_encode.append(3)
# 然后根据编码进行分段, 选出来 1,2,3连续出现至少2次的行,认为是列表。
list_indice, list_start_idx = find_repeating_patterns(line_fea_encode)
if len(list_indice)>0:
logger.info(f"发现了列表,列表行数:{list_indice}{list_start_idx}")
list_indice, list_start_idx = find_repeating_patterns(line_fea_encode)
if len(list_indice) > 0:
logger.info(f'发现了列表,列表行数:{list_indice}{list_start_idx}')
# TODO check一下这个特列表里缩进的行左侧是不是对齐的。
segments = []
for start, end in list_indice:
for i in range(start, end+1):
if i>0:
for i in range(start, end + 1):
if i > 0:
if line_fea_encode[i] == 4:
logger.info(f"列表行的第{i}行不是顶格的")
logger.info(f'列表行的第{i}行不是顶格的')
break
else:
logger.info(f"列表行的第{start}到第{end}行是列表")
logger.info(f'列表行的第{start}到第{end}行是列表')
return split_indices(total_lines, list_indice), list_start_idx
def __valign_lines(blocks, layout_bboxes):
"""
在一个layoutbox内对齐行的左侧和右侧。
扫描行的左侧和右侧,如果x0, x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。
"""
"""在一个layoutbox内对齐行的左侧和右侧。 扫描行的左侧和右侧,如果x0,
x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。"""
min_distance = 3
min_sample = 2
new_layout_bboxes = []
for layout_box in layout_bboxes:
blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], layout_box['layout_bbox'])]
if len(blocks_in_layoutbox)==0:
blocks_in_layoutbox = [
b for b in blocks if is_in_layout(b['bbox'], layout_box['layout_bbox'])
]
if len(blocks_in_layoutbox) == 0:
continue
x0_lst = np.array([[line['bbox'][0], 0] for block in blocks_in_layoutbox for line in block['lines']])
x1_lst = np.array([[line['bbox'][2], 0] for block in blocks_in_layoutbox for line in block['lines']])
x0_lst = np.array(
[
[line['bbox'][0], 0]
for block in blocks_in_layoutbox
for line in block['lines']
]
)
x1_lst = np.array(
[
[line['bbox'][2], 0]
for block in blocks_in_layoutbox
for line in block['lines']
]
)
x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst)
x1_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x1_lst)
x0_uniq_label = np.unique(x0_clusters.labels_)
x1_uniq_label = np.unique(x1_clusters.labels_)
x0_2_new_val = {} # 存储旧值对应的新值映射
x0_2_new_val = {} # 存储旧值对应的新值映射
x1_2_new_val = {}
for label in x0_uniq_label:
if label==-1:
if label == -1:
continue
x0_index_of_label = np.where(x0_clusters.labels_==label)
x0_raw_val = x0_lst[x0_index_of_label][:,0]
x0_new_val = np.min(x0_lst[x0_index_of_label][:,0])
x0_index_of_label = np.where(x0_clusters.labels_ == label)
x0_raw_val = x0_lst[x0_index_of_label][:, 0]
x0_new_val = np.min(x0_lst[x0_index_of_label][:, 0])
x0_2_new_val.update({idx: x0_new_val for idx in x0_raw_val})
for label in x1_uniq_label:
if label==-1:
if label == -1:
continue
x1_index_of_label = np.where(x1_clusters.labels_==label)
x1_raw_val = x1_lst[x1_index_of_label][:,0]
x1_new_val = np.max(x1_lst[x1_index_of_label][:,0])
x1_index_of_label = np.where(x1_clusters.labels_ == label)
x1_raw_val = x1_lst[x1_index_of_label][:, 0]
x1_new_val = np.max(x1_lst[x1_index_of_label][:, 0])
x1_2_new_val.update({idx: x1_new_val for idx in x1_raw_val})
for block in blocks_in_layoutbox:
for line in block['lines']:
x0, x1 = line['bbox'][0], line['bbox'][2]
......@@ -165,34 +185,34 @@ def __valign_lines(blocks, layout_bboxes):
if x1 in x1_2_new_val:
line['bbox'][2] = int(x1_2_new_val[x1])
# 其余对不齐的保持不动
# 由于修改了block里的line长度,现在需要重新计算block的bbox
for block in blocks_in_layoutbox:
block['bbox'] = [min([line['bbox'][0] for line in block['lines']]),
min([line['bbox'][1] for line in block['lines']]),
max([line['bbox'][2] for line in block['lines']]),
max([line['bbox'][3] for line in block['lines']])]
block['bbox'] = [
min([line['bbox'][0] for line in block['lines']]),
min([line['bbox'][1] for line in block['lines']]),
max([line['bbox'][2] for line in block['lines']]),
max([line['bbox'][3] for line in block['lines']]),
]
"""新计算layout的bbox,因为block的bbox变了。"""
layout_x0 = min([block['bbox'][0] for block in blocks_in_layoutbox])
layout_y0 = min([block['bbox'][1] for block in blocks_in_layoutbox])
layout_x1 = max([block['bbox'][2] for block in blocks_in_layoutbox])
layout_y1 = max([block['bbox'][3] for block in blocks_in_layoutbox])
new_layout_bboxes.append([layout_x0, layout_y0, layout_x1, layout_y1])
return new_layout_bboxes
def __align_text_in_layout(blocks, layout_bboxes):
"""
由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。
"""
"""由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。"""
for layout in layout_bboxes:
lb = layout['layout_bbox']
blocks_in_layoutbox = [b for b in blocks if is_in_layout(b['bbox'], lb)]
if len(blocks_in_layoutbox)==0:
if len(blocks_in_layoutbox) == 0:
continue
for block in blocks_in_layoutbox:
for line in block['lines']:
x0, x1 = line['bbox'][0], line['bbox'][2]
......@@ -200,69 +220,67 @@ def __align_text_in_layout(blocks, layout_bboxes):
line['bbox'][0] = lb[0]
if x1 > lb[2]:
line['bbox'][2] = lb[2]
def __common_pre_proc(blocks, layout_bboxes):
"""
不分语言的,对文本进行预处理
"""
#__add_line_period(blocks, layout_bboxes)
"""不分语言的,对文本进行预处理."""
# __add_line_period(blocks, layout_bboxes)
__align_text_in_layout(blocks, layout_bboxes)
aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)
return aligned_layout_bboxes
def __pre_proc_zh_blocks(blocks, layout_bboxes):
"""
对中文文本进行分段预处理
"""
"""对中文文本进行分段预处理."""
pass
def __pre_proc_en_blocks(blocks, layout_bboxes):
"""
对英文文本进行分段预处理
"""
"""对英文文本进行分段预处理."""
pass
def __group_line_by_layout(blocks, layout_bboxes, lang="en"):
"""
每个layout内的行进行聚合
"""
def __group_line_by_layout(blocks, layout_bboxes, lang='en'):
"""每个layout内的行进行聚合."""
# 因为只是一个block一行目前, 一个block就是一个段落
lines_group = []
for lyout in layout_bboxes:
lines = [line for block in blocks if is_in_layout(block['bbox'], lyout['layout_bbox']) for line in block['lines']]
lines = [
line
for block in blocks
if is_in_layout(block['bbox'], lyout['layout_bbox'])
for line in block['lines']
]
lines_group.append(lines)
return lines_group
def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_len=10):
def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang='en', char_avg_len=10):
"""
lines_group 进行行分段——layout内部进行分段。lines_group内每个元素是一个Layoutbox内的所有行。
1. 先计算每个group的左右边界。
2. 然后根据行末尾特征进行分段。
末尾特征:以句号等结束符结尾。并且距离右侧边界有一定距离。
且下一行开头不留空白。
"""
list_info = [] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
list_info = [] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
layout_paras = []
right_tail_distance = 1.5 * char_avg_len
for lines in lines_group:
paras = []
total_lines = len(lines)
if total_lines==0:
continue # 0行无需处理
if total_lines==1: # 1行无法分段。
if total_lines == 0:
continue # 0行无需处理
if total_lines == 1: # 1行无法分段。
layout_paras.append([lines])
list_info.append([False, False])
continue
"""在进入到真正的分段之前,要对文字块从统计维度进行对齐方式的探测,
对齐方式分为以下:
1. 左对齐的文本块(特点是左侧顶格,或者左侧不顶格但是右侧顶格的行数大于非顶格的行数,顶格的首字母有大写也有小写)
......@@ -271,139 +289,183 @@ def __split_para_in_layoutbox(lines_group, new_layout_bbox, lang="en", char_avg_
2. 左对齐的列表块(其特点是左侧顶格的行数小于等于非顶格的行数,非定格首字母会有小写,顶格90%是大写。并且左侧顶格行数大于1,大于1是为了这种模式连续出现才能称之为列表)
这样的文本块,顶格的为一个段落开头,紧随其后非顶格的行属于这个段落。
"""
text_segments, list_start_line = __detect_list_lines(lines, new_layout_bbox, lang)
text_segments, list_start_line = __detect_list_lines(
lines, new_layout_bbox, lang
)
"""根据list_range,把lines分成几个部分
"""
layout_right = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[2]
layout_left = __find_layout_bbox_by_line(lines[0]['bbox'], new_layout_bbox)[0]
para = [] # 元素是line
layout_list_info = [False, False] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
para = [] # 元素是line
layout_list_info = [
False,
False,
] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
for content_type, start, end in text_segments:
if content_type == 'list':
for i, line in enumerate(lines[start:end+1]):
for i, line in enumerate(lines[start : end + 1]):
line_x0 = line['bbox'][0]
if line_x0 == layout_left: # 列表开头
if len(para)>0:
if line_x0 == layout_left: # 列表开头
if len(para) > 0:
paras.append(para)
para = []
para.append(line)
else:
para.append(line)
if len(para)>0:
if len(para) > 0:
paras.append(para)
para = []
if start==0:
if start == 0:
layout_list_info[0] = True
if end==total_lines-1:
if end == total_lines - 1:
layout_list_info[1] = True
else: # 是普通文本
for i, line in enumerate(lines[start:end+1]):
else: # 是普通文本
for i, line in enumerate(lines[start : end + 1]):
# 如果i有下一行,那么就要根据下一行位置综合判断是否要分段。如果i之后没有行,那么只需要判断i行自己的结尾特征。
cur_line_type = line['spans'][-1]['type']
next_line = lines[i+1] if i<total_lines-1 else None
next_line = lines[i + 1] if i < total_lines - 1 else None
if cur_line_type in [TEXT, INLINE_EQUATION]:
if line['bbox'][2] < layout_right - right_tail_distance:
para.append(line)
paras.append(para)
para = []
elif line['bbox'][2] >= layout_right - right_tail_distance and next_line and next_line['bbox'][0] == layout_left: # 现在这行到了行尾沾满,下一行存在且顶格。
elif (
line['bbox'][2] >= layout_right - right_tail_distance
and next_line
and next_line['bbox'][0] == layout_left
): # 现在这行到了行尾沾满,下一行存在且顶格。
para.append(line)
else:
else:
para.append(line)
paras.append(para)
para = []
else: # 其他,图片、表格、行间公式,各自占一段
if len(para)>0: # 先把之前的段落加入到结果中
else: # 其他,图片、表格、行间公式,各自占一段
if len(para) > 0: # 先把之前的段落加入到结果中
paras.append(para)
para = []
paras.append([line]) # 再把当前行加入到结果中。当前行为行间公式、图、表等。
paras.append(
[line]
) # 再把当前行加入到结果中。当前行为行间公式、图、表等。
para = []
if len(para)>0:
if len(para) > 0:
paras.append(para)
para = []
list_info.append(layout_list_info)
layout_paras.append(paras)
paras = []
return layout_paras, list_info
def __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang):
"""
如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。
"""
if len(layout_paras)==0 or len(layout_list_info)==0: # 0的时候最后的return 会出错
def __connect_list_inter_layout(
layout_paras, new_layout_bbox, layout_list_info, page_num, lang
):
"""如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
if (
len(layout_paras) == 0 or len(layout_list_info) == 0
): # 0的时候最后的return 会出错
return layout_paras, [False, False]
for i in range(1, len(layout_paras)):
pre_layout_list_info = layout_list_info[i-1]
pre_layout_list_info = layout_list_info[i - 1]
next_layout_list_info = layout_list_info[i]
pre_last_para = layout_paras[i-1][-1]
pre_last_para = layout_paras[i - 1][-1]
next_paras = layout_paras[i]
next_first_para = next_paras[0]
if pre_layout_list_info[1] and not next_layout_list_info[0]: # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
logger.info(f"连接page {page_num} 内的list")
if (
pre_layout_list_info[1] and not next_layout_list_info[0]
): # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
logger.info(f'连接page {page_num} 内的list')
# 向layout_paras[i] 寻找开头具有相同缩进的连续的行
may_list_lines = []
for j in range(len(next_paras)):
line = next_paras[j]
if len(line)==1: # 只可能是一行,多行情况再需要分析了
if line[0]['bbox'][0] > __find_layout_bbox_by_line(line[0]['bbox'], new_layout_bbox)[0]:
if len(line) == 1: # 只可能是一行,多行情况再需要分析了
if (
line[0]['bbox'][0]
> __find_layout_bbox_by_line(line[0]['bbox'], new_layout_bbox)[
0
]
):
may_list_lines.append(line[0])
else:
break
else:
break
# 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
if len(may_list_lines)>0 and len(set([x['bbox'][0] for x in may_list_lines]))==1:
if (
len(may_list_lines) > 0
and len(set([x['bbox'][0] for x in may_list_lines])) == 1
):
pre_last_para.extend(may_list_lines)
layout_paras[i] = layout_paras[i][len(may_list_lines):]
return layout_paras, [layout_list_info[0][0], layout_list_info[-1][1]] # 同时还返回了这个页面级别的开头、结尾是不是列表的信息
def __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, pre_page_list_info, next_page_list_info, page_num, lang):
"""
如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。
"""
if len(pre_page_paras)==0 or len(next_page_paras)==0: # 0的时候最后的return 会出错
layout_paras[i] = layout_paras[i][len(may_list_lines) :]
return layout_paras, [
layout_list_info[0][0],
layout_list_info[-1][1],
] # 同时还返回了这个页面级别的开头、结尾是不是列表的信息
def __connect_list_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
pre_page_list_info,
next_page_list_info,
page_num,
lang,
):
"""如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
if (
len(pre_page_paras) == 0 or len(next_page_paras) == 0
): # 0的时候最后的return 会出错
return False
if pre_page_list_info[1] and not next_page_list_info[0]: # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
logger.info(f"连接page {page_num} 内的list")
if (
pre_page_list_info[1] and not next_page_list_info[0]
): # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
logger.info(f'连接page {page_num} 内的list')
# 向layout_paras[i] 寻找开头具有相同缩进的连续的行
may_list_lines = []
for j in range(len(next_page_paras[0])):
line = next_page_paras[0][j]
if len(line)==1: # 只可能是一行,多行情况再需要分析了
if line[0]['bbox'][0] > __find_layout_bbox_by_line(line[0]['bbox'], next_page_layout_bbox)[0]:
if len(line) == 1: # 只可能是一行,多行情况再需要分析了
if (
line[0]['bbox'][0]
> __find_layout_bbox_by_line(
line[0]['bbox'], next_page_layout_bbox
)[0]
):
may_list_lines.append(line[0])
else:
break
else:
break
# 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
if len(may_list_lines)>0 and len(set([x['bbox'][0] for x in may_list_lines]))==1:
if (
len(may_list_lines) > 0
and len(set([x['bbox'][0] for x in may_list_lines])) == 1
):
pre_page_paras[-1].append(may_list_lines)
next_page_paras[0] = next_page_paras[0][len(may_list_lines):]
next_page_paras[0] = next_page_paras[0][len(may_list_lines) :]
return True
return False
def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
"""
根据line找到所在的layout
"""
"""根据line找到所在的layout."""
for layout in layout_bboxes:
if is_in_layout(line_bbox, layout):
return layout
......@@ -420,48 +482,74 @@ def __connect_para_inter_layoutbox(layout_paras, new_layout_bbox, lang):
"""
connected_layout_paras = []
if len(layout_paras)==0:
if len(layout_paras) == 0:
return connected_layout_paras
connected_layout_paras.append(layout_paras[0])
for i in range(1, len(layout_paras)):
try:
if len(layout_paras[i])==0 or len(layout_paras[i-1])==0: # TODO 考虑连接问题,
if (
len(layout_paras[i]) == 0 or len(layout_paras[i - 1]) == 0
): # TODO 考虑连接问题,
continue
pre_last_line = layout_paras[i-1][-1][-1]
pre_last_line = layout_paras[i - 1][-1][-1]
next_first_line = layout_paras[i][0][0]
except Exception as e:
logger.error(f"page layout {i} has no line")
except Exception:
logger.error(f'page layout {i} has no line')
continue
pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
pre_last_line_text = ''.join(
[__get_span_text(span) for span in pre_last_line['spans']]
)
pre_last_line_type = pre_last_line['spans'][-1]['type']
next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
next_first_line_text = ''.join(
[__get_span_text(span) for span in next_first_line['spans']]
)
next_first_line_type = next_first_line['spans'][0]['type']
if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
if pre_last_line_type not in [
TEXT,
INLINE_EQUATION,
] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
connected_layout_paras.append(layout_paras[i])
continue
pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)[2]
next_x0_min = __find_layout_bbox_by_line(next_first_line['bbox'], new_layout_bbox)[0]
pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)[
2
]
next_x0_min = __find_layout_bbox_by_line(
next_first_line['bbox'], new_layout_bbox
)[0]
pre_last_line_text = pre_last_line_text.strip()
next_first_line_text = next_first_line_text.strip()
if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and next_first_line['bbox'][0]==next_x0_min: # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
if (
pre_last_line['bbox'][2] == pre_x2_max
and pre_last_line_text[-1] not in LINE_STOP_FLAG
and next_first_line['bbox'][0] == next_x0_min
): # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
"""连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
connected_layout_paras[-1][-1].extend(layout_paras[i][0])
layout_paras[i].pop(0) # 删除后一个layout的第一个段落, 因为他已经被合并到前一个layout的最后一个段落了。
if len(layout_paras[i])==0:
layout_paras[i].pop(
0
) # 删除后一个layout的第一个段落, 因为他已经被合并到前一个layout的最后一个段落了。
if len(layout_paras[i]) == 0:
layout_paras.pop(i)
else:
connected_layout_paras.append(layout_paras[i])
else:
else:
"""连接段落条件不成立,将前一个layout的段落加入到结果中。"""
connected_layout_paras.append(layout_paras[i])
return connected_layout_paras
def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num, lang):
def __connect_para_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
page_num,
lang,
):
"""
连接起来相邻两个页面的段落——前一个页面最后一个段落和后一个页面的第一个段落。
是否可以连接的条件:
......@@ -469,34 +557,60 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
2. 后一个页面的第一个段落第一行没有空白开头。
"""
# 有的页面可能压根没有文字
if len(pre_page_paras)==0 or len(next_page_paras)==0 or len(pre_page_paras[0])==0 or len(next_page_paras[0])==0: # TODO [[]]为什么出现在pre_page_paras里?
if (
len(pre_page_paras) == 0
or len(next_page_paras) == 0
or len(pre_page_paras[0]) == 0
or len(next_page_paras[0]) == 0
): # TODO [[]]为什么出现在pre_page_paras里?
return False
pre_last_para = pre_page_paras[-1][-1]
next_first_para = next_page_paras[0][0]
pre_last_line = pre_last_para[-1]
next_first_line = next_first_para[0]
pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
pre_last_line_text = ''.join(
[__get_span_text(span) for span in pre_last_line['spans']]
)
pre_last_line_type = pre_last_line['spans'][-1]['type']
next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
next_first_line_text = ''.join(
[__get_span_text(span) for span in next_first_line['spans']]
)
next_first_line_type = next_first_line['spans'][0]['type']
if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]: # TODO,真的要做好,要考虑跨table, image, 行间的情况
if pre_last_line_type not in [
TEXT,
INLINE_EQUATION,
] or next_first_line_type not in [
TEXT,
INLINE_EQUATION,
]: # TODO,真的要做好,要考虑跨table, image, 行间的情况
# 不是文本,不连接
return False
pre_x2_max = __find_layout_bbox_by_line(pre_last_line['bbox'], pre_page_layout_bbox)[2]
next_x0_min = __find_layout_bbox_by_line(next_first_line['bbox'], next_page_layout_bbox)[0]
pre_x2_max = __find_layout_bbox_by_line(
pre_last_line['bbox'], pre_page_layout_bbox
)[2]
next_x0_min = __find_layout_bbox_by_line(
next_first_line['bbox'], next_page_layout_bbox
)[0]
pre_last_line_text = pre_last_line_text.strip()
next_first_line_text = next_first_line_text.strip()
if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and next_first_line['bbox'][0]==next_x0_min: # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
if (
pre_last_line['bbox'][2] == pre_x2_max
and pre_last_line_text[-1] not in LINE_STOP_FLAG
and next_first_line['bbox'][0] == next_x0_min
): # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
"""连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
pre_last_para.extend(next_first_para)
next_page_paras[0].pop(0) # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。
next_page_paras[0].pop(
0
) # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。
return True
else:
return False
def find_consecutive_true_regions(input_array):
start_index = None # 连续True区域的起始索引
regions = [] # 用于保存所有连续True区域的起始和结束索引
......@@ -509,77 +623,103 @@ def find_consecutive_true_regions(input_array):
# 如果我们找到了一个False值,并且当前在连续True区域中
elif not input_array[i] and start_index is not None:
# 如果连续True区域长度大于1,那么将其添加到结果列表中
if i - start_index > 1:
regions.append((start_index, i-1))
if i - start_index > 1:
regions.append((start_index, i - 1))
start_index = None # 重置起始索引
# 如果最后一个元素是True,那么需要将最后一个连续True区域加入到结果列表中
if start_index is not None and len(input_array) - start_index > 1:
regions.append((start_index, len(input_array)-1))
regions.append((start_index, len(input_array) - 1))
return regions
def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode):
def __connect_middle_align_text(
page_paras, new_layout_bbox, page_num, lang, debug_mode
):
"""
找出来中间对齐的连续单行文本,如果连续行高度相同,那么合并为一个段落。
一个line居中的条件是:
1. 水平中心点跨越layout的中心点。
2. 左右两侧都有空白
"""
for layout_i, layout_para in enumerate(page_paras):
layout_box = new_layout_bbox[layout_i]
single_line_paras_tag = []
for i in range(len(layout_para)):
single_line_paras_tag.append(len(layout_para[i])==1 and layout_para[i][0]['spans'][0]['type']==TEXT)
single_line_paras_tag.append(
len(layout_para[i]) == 1
and layout_para[i][0]['spans'][0]['type'] == TEXT
)
"""找出来连续的单行文本,如果连续行高度相同,那么合并为一个段落。"""
consecutive_single_line_indices = find_consecutive_true_regions(single_line_paras_tag)
if len(consecutive_single_line_indices)>0:
consecutive_single_line_indices = find_consecutive_true_regions(
single_line_paras_tag
)
if len(consecutive_single_line_indices) > 0:
index_offset = 0
"""检查这些行是否是高度相同的,居中的"""
for start, end in consecutive_single_line_indices:
start += index_offset
end += index_offset
line_hi = np.array([line[0]['bbox'][3]-line[0]['bbox'][1] for line in layout_para[start:end+1]])
first_line_text = ''.join([__get_span_text(span) for span in layout_para[start][0]['spans']])
if "Table" in first_line_text or "Figure" in first_line_text:
line_hi = np.array(
[
line[0]['bbox'][3] - line[0]['bbox'][1]
for line in layout_para[start : end + 1]
]
)
first_line_text = ''.join(
[__get_span_text(span) for span in layout_para[start][0]['spans']]
)
if 'Table' in first_line_text or 'Figure' in first_line_text:
pass
if debug_mode:
logger.debug(line_hi.std())
if line_hi.std()<2:
"""行高度相同,那么判断是否居中"""
all_left_x0 = [line[0]['bbox'][0] for line in layout_para[start:end+1]]
all_right_x1 = [line[0]['bbox'][2] for line in layout_para[start:end+1]]
if line_hi.std() < 2:
"""行高度相同,那么判断是否居中."""
all_left_x0 = [
line[0]['bbox'][0] for line in layout_para[start : end + 1]
]
all_right_x1 = [
line[0]['bbox'][2] for line in layout_para[start : end + 1]
]
layout_center = (layout_box[0] + layout_box[2]) / 2
if all([x0 < layout_center < x1 for x0, x1 in zip(all_left_x0, all_right_x1)]) \
and not all([x0==layout_box[0] for x0 in all_left_x0]) \
and not all([x1==layout_box[2] for x1 in all_right_x1]):
merge_para = [l[0] for l in layout_para[start:end+1]]
para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']])
if (
all(
[
x0 < layout_center < x1
for x0, x1 in zip(all_left_x0, all_right_x1)
]
)
and not all([x0 == layout_box[0] for x0 in all_left_x0])
and not all([x1 == layout_box[2] for x1 in all_right_x1])
):
merge_para = [l[0] for l in layout_para[start : end + 1]] # noqa: E741
para_text = ''.join(
[
__get_span_text(span)
for line in merge_para
for span in line['spans']
]
)
if debug_mode:
logger.debug(para_text)
layout_para[start:end+1] = [merge_para]
index_offset -= end-start
layout_para[start : end + 1] = [merge_para]
index_offset -= end - start
return
def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang):
"""
找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。
"""
"""找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。"""
pass
def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
"""
根据line和layout情况进行分段
先实现一个根据行末尾特征分段的简单方法。
"""
"""根据line和layout情况进行分段 先实现一个根据行末尾特征分段的简单方法。"""
"""
算法思路:
1. 扫描layout里每一行,找出来行尾距离layout有边界有一定距离的行。
......@@ -587,52 +727,73 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
3. 参照上述行尾特征进行分段。
4. 图、表,目前独占一行,不考虑分段。
"""
if page_num==343:
if page_num == 343:
pass
lines_group = __group_line_by_layout(blocks, layout_bboxes, lang) # block内分段
layout_paras, layout_list_info = __split_para_in_layoutbox(lines_group, new_layout_bbox, lang) # layout内分段
layout_paras2, page_list_info = __connect_list_inter_layout(layout_paras, new_layout_bbox, layout_list_info, page_num, lang) # layout之间连接列表段落
connected_layout_paras = __connect_para_inter_layoutbox(layout_paras2, new_layout_bbox, lang) # layout间链接段落
lines_group = __group_line_by_layout(blocks, layout_bboxes, lang) # block内分段
layout_paras, layout_list_info = __split_para_in_layoutbox(
lines_group, new_layout_bbox, lang
) # layout内分段
layout_paras2, page_list_info = __connect_list_inter_layout(
layout_paras, new_layout_bbox, layout_list_info, page_num, lang
) # layout之间连接列表段落
connected_layout_paras = __connect_para_inter_layoutbox(
layout_paras2, new_layout_bbox, lang
) # layout间链接段落
return connected_layout_paras, page_list_info
def para_split(pdf_info_dict, debug_mode, lang="en"):
"""
根据line和layout情况进行分段
"""
new_layout_of_pages = [] # 数组的数组,每个元素是一个页面的layoutS
all_page_list_info = [] # 保存每个页面开头和结尾是否是列表
def para_split(pdf_info_dict, debug_mode, lang='en'):
"""根据line和layout情况进行分段."""
new_layout_of_pages = [] # 数组的数组,每个元素是一个页面的layoutS
all_page_list_info = [] # 保存每个页面开头和结尾是否是列表
for page_num, page in pdf_info_dict.items():
blocks = page['preproc_blocks']
layout_bboxes = page['layout_bboxes']
new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
new_layout_of_pages.append(new_layout_bbox)
splited_blocks, page_list_info = __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang)
splited_blocks, page_list_info = __do_split_page(
blocks, layout_bboxes, new_layout_bbox, page_num, lang
)
all_page_list_info.append(page_list_info)
page['para_blocks'] = splited_blocks
"""连接页面与页面之间的可能合并的段落"""
pdf_infos = list(pdf_info_dict.values())
for page_num, page in enumerate(pdf_info_dict.values()):
if page_num==0:
if page_num == 0:
continue
pre_page_paras = pdf_infos[page_num-1]['para_blocks']
pre_page_paras = pdf_infos[page_num - 1]['para_blocks']
next_page_paras = pdf_infos[page_num]['para_blocks']
pre_page_layout_bbox = new_layout_of_pages[page_num-1]
pre_page_layout_bbox = new_layout_of_pages[page_num - 1]
next_page_layout_bbox = new_layout_of_pages[page_num]
is_conn = __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num, lang)
is_conn = __connect_para_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
page_num,
lang,
)
if debug_mode:
if is_conn:
logger.info(f"连接了第{page_num-1}页和第{page_num}页的段落")
is_list_conn = __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, all_page_list_info[page_num-1], all_page_list_info[page_num], page_num, lang)
logger.info(f'连接了第{page_num-1}页和第{page_num}页的段落')
is_list_conn = __connect_list_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
all_page_list_info[page_num - 1],
all_page_list_info[page_num],
page_num,
lang,
)
if debug_mode:
if is_list_conn:
logger.info(f"连接了第{page_num-1}页和第{page_num}页的列表段落")
logger.info(f'连接了第{page_num-1}页和第{page_num}页的列表段落')
"""接下来可能会漏掉一些特别的一些可以合并的内容,对他们进行段落连接
1. 正文中有时出现一个行顶格,接下来几行缩进的情况。
2. 居中的一些连续单行,如果高度相同,那么可能是一个段落。
......@@ -640,5 +801,7 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
for page_num, page in enumerate(pdf_info_dict.values()):
page_paras = page['para_blocks']
new_layout_bbox = new_layout_of_pages[page_num]
__connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang, debug_mode=debug_mode)
__connect_middle_align_text(
page_paras, new_layout_bbox, page_num, lang, debug_mode=debug_mode
)
__merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang)
import copy
import re
from sklearn.cluster import DBSCAN
import numpy as np
from loguru import logger
import re
from magic_pdf.libs.boxbase import _is_in_or_part_overlap_with_area_ratio as is_in_layout
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.libs.Constants import *
from sklearn.cluster import DBSCAN
from magic_pdf.config.constants import * # noqa: F403
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.libs.boxbase import \
_is_in_or_part_overlap_with_area_ratio as is_in_layout
LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ":", ":", ")", ")", ";"]
LINE_STOP_FLAG = ['.', '!', '?', '。', '!', '?', ':', ':', ')', ')', ';']
INLINE_EQUATION = ContentType.InlineEquation
INTERLINE_EQUATION = ContentType.InterlineEquation
TEXT = ContentType.Text
......@@ -36,7 +37,9 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
ones_indices = []
i = 0
while i < len(lst): # Loop through the entire list
if lst[i] == 1: # If we encounter a '1', we might be at the start of a pattern
if (
lst[i] == 1
): # If we encounter a '1', we might be at the start of a pattern
start = i
ones_in_this_interval = [i]
i += 1
......@@ -46,7 +49,10 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
ones_in_this_interval.append(i)
i += 1
if len(ones_in_this_interval) > 1 or (
start < len(lst) - 1 and ones_in_this_interval and lst[start + 1] in [2, 3]):
start < len(lst) - 1
and ones_in_this_interval
and lst[start + 1] in [2, 3]
):
indices.append((start, i - 1))
ones_indices.append(ones_in_this_interval)
else:
......@@ -65,7 +71,12 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
while i < len(lst) and lst[i] in [2, 3]:
i += 1
# 验证下一个序列是否符合条件
if i < len(lst) - 1 and lst[i] == 1 and lst[i + 1] in [2, 3] and lst[i - 1] in [2, 3]:
if (
i < len(lst) - 1
and lst[i] == 1
and lst[i + 1] in [2, 3]
and lst[i - 1] in [2, 3]
):
while i < len(lst) and lst[i] in [1, 2, 3]:
if lst[i] == 1:
ones_in_this_interval.append(i)
......@@ -114,7 +125,7 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
"""
if len(lines) > 0:
x_map_tag_dict, min_x_tag = cluster_line_x(lines)
for l in lines:
for l in lines: # noqa: E741
span_text = __get_span_text(l['spans'][0])
if not span_text:
line_fea_encode.append(0)
......@@ -142,28 +153,26 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
list_indice, list_start_idx = find_repeating_patterns2(line_fea_encode)
if len(list_indice) > 0:
if debug_able:
logger.info(f"发现了列表,列表行数:{list_indice}{list_start_idx}")
logger.info(f'发现了列表,列表行数:{list_indice}{list_start_idx}')
# TODO check一下这个特列表里缩进的行左侧是不是对齐的。
segments = []
for start, end in list_indice:
for i in range(start, end + 1):
if i > 0:
if line_fea_encode[i] == 4:
if debug_able:
logger.info(f"列表行的第{i}行不是顶格的")
logger.info(f'列表行的第{i}行不是顶格的')
break
else:
if debug_able:
logger.info(f"列表行的第{start}到第{end}行是列表")
logger.info(f'列表行的第{start}到第{end}行是列表')
return split_indices(total_lines, list_indice), list_start_idx
def cluster_line_x(lines: list) -> dict:
"""
对一个block内所有lines的bbox的x0聚类
"""
"""对一个block内所有lines的bbox的x0聚类."""
min_distance = 5
min_sample = 1
x0_lst = np.array([[round(line['bbox'][0]), 0] for line in lines])
......@@ -171,14 +180,16 @@ def cluster_line_x(lines: list) -> dict:
x0_uniq_label = np.unique(x0_clusters.labels_)
# x1_lst = np.array([[line['bbox'][2], 0] for line in lines])
x0_2_new_val = {} # 存储旧值对应的新值映射
min_x0 = round(lines[0]["bbox"][0])
min_x0 = round(lines[0]['bbox'][0])
for label in x0_uniq_label:
if label == -1:
continue
x0_index_of_label = np.where(x0_clusters.labels_ == label)
x0_raw_val = x0_lst[x0_index_of_label][:, 0]
x0_new_val = np.min(x0_lst[x0_index_of_label][:, 0])
x0_2_new_val.update({round(raw_val): round(x0_new_val) for raw_val in x0_raw_val})
x0_2_new_val.update(
{round(raw_val): round(x0_new_val) for raw_val in x0_raw_val}
)
if x0_new_val < min_x0:
min_x0 = x0_new_val
return x0_2_new_val, min_x0
......@@ -193,27 +204,41 @@ def if_match_reference_list(text: str) -> bool:
def __valign_lines(blocks, layout_bboxes):
"""
在一个layoutbox内对齐行的左侧和右侧。
扫描行的左侧和右侧,如果x0, x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。
"""
"""在一个layoutbox内对齐行的左侧和右侧。 扫描行的左侧和右侧,如果x0,
x1差距不超过一个阈值,就强行对齐到所处layout的左右两侧(和layout有一段距离)。
3是个经验值,TODO,计算得来,可以设置为1.5个正文字符。"""
min_distance = 3
min_sample = 2
new_layout_bboxes = []
# add bbox_fs for para split calculation
for block in blocks:
block["bbox_fs"] = copy.deepcopy(block["bbox"])
block['bbox_fs'] = copy.deepcopy(block['bbox'])
for layout_box in layout_bboxes:
blocks_in_layoutbox = [b for b in blocks if
b["type"] == BlockType.Text and is_in_layout(b['bbox'], layout_box['layout_bbox'])]
if len(blocks_in_layoutbox) == 0 or len(blocks_in_layoutbox[0]["lines"]) == 0:
blocks_in_layoutbox = [
b
for b in blocks
if b['type'] == BlockType.Text
and is_in_layout(b['bbox'], layout_box['layout_bbox'])
]
if len(blocks_in_layoutbox) == 0 or len(blocks_in_layoutbox[0]['lines']) == 0:
new_layout_bboxes.append(layout_box['layout_bbox'])
continue
x0_lst = np.array([[line['bbox'][0], 0] for block in blocks_in_layoutbox for line in block['lines']])
x1_lst = np.array([[line['bbox'][2], 0] for block in blocks_in_layoutbox for line in block['lines']])
x0_lst = np.array(
[
[line['bbox'][0], 0]
for block in blocks_in_layoutbox
for line in block['lines']
]
)
x1_lst = np.array(
[
[line['bbox'][2], 0]
for block in blocks_in_layoutbox
for line in block['lines']
]
)
x0_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x0_lst)
x1_clusters = DBSCAN(eps=min_distance, min_samples=min_sample).fit(x1_lst)
x0_uniq_label = np.unique(x0_clusters.labels_)
......@@ -248,11 +273,13 @@ def __valign_lines(blocks, layout_bboxes):
# 由于修改了block里的line长度,现在需要重新计算block的bbox
for block in blocks_in_layoutbox:
if len(block["lines"]) > 0:
block['bbox_fs'] = [min([line['bbox'][0] for line in block['lines']]),
min([line['bbox'][1] for line in block['lines']]),
max([line['bbox'][2] for line in block['lines']]),
max([line['bbox'][3] for line in block['lines']])]
if len(block['lines']) > 0:
block['bbox_fs'] = [
min([line['bbox'][0] for line in block['lines']]),
min([line['bbox'][1] for line in block['lines']]),
max([line['bbox'][2] for line in block['lines']]),
max([line['bbox'][3] for line in block['lines']]),
]
"""新计算layout的bbox,因为block的bbox变了。"""
layout_x0 = min([block['bbox_fs'][0] for block in blocks_in_layoutbox])
layout_y0 = min([block['bbox_fs'][1] for block in blocks_in_layoutbox])
......@@ -264,18 +291,19 @@ def __valign_lines(blocks, layout_bboxes):
def __align_text_in_layout(blocks, layout_bboxes):
"""
由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。
"""
"""由于ocr出来的line,有时候会在前后有一段空白,这个时候需要对文本进行对齐,超出的部分被layout左右侧截断。"""
for layout in layout_bboxes:
lb = layout['layout_bbox']
blocks_in_layoutbox = [block for block in blocks if
block["type"] == BlockType.Text and is_in_layout(block['bbox'], lb)]
blocks_in_layoutbox = [
block
for block in blocks
if block['type'] == BlockType.Text and is_in_layout(block['bbox'], lb)
]
if len(blocks_in_layoutbox) == 0:
continue
for block in blocks_in_layoutbox:
for line in block.get("lines", []):
for line in block.get('lines', []):
x0, x1 = line['bbox'][0], line['bbox'][2]
if x0 < lb[0]:
line['bbox'][0] = lb[0]
......@@ -284,9 +312,7 @@ def __align_text_in_layout(blocks, layout_bboxes):
def __common_pre_proc(blocks, layout_bboxes):
"""
不分语言的,对文本进行预处理
"""
"""不分语言的,对文本进行预处理."""
# __add_line_period(blocks, layout_bboxes)
__align_text_in_layout(blocks, layout_bboxes)
aligned_layout_bboxes = __valign_lines(blocks, layout_bboxes)
......@@ -295,32 +321,30 @@ def __common_pre_proc(blocks, layout_bboxes):
def __pre_proc_zh_blocks(blocks, layout_bboxes):
"""
对中文文本进行分段预处理
"""
"""对中文文本进行分段预处理."""
pass
def __pre_proc_en_blocks(blocks, layout_bboxes):
"""
对英文文本进行分段预处理
"""
"""对英文文本进行分段预处理."""
pass
def __group_line_by_layout(blocks, layout_bboxes):
"""
每个layout内的行进行聚合
"""
"""每个layout内的行进行聚合."""
# 因为只是一个block一行目前, 一个block就是一个段落
blocks_group = []
for lyout in layout_bboxes:
blocks_in_layout = [block for block in blocks if is_in_layout(block.get('bbox_fs', None), lyout['layout_bbox'])]
blocks_in_layout = [
block
for block in blocks
if is_in_layout(block.get('bbox_fs', None), lyout['layout_bbox'])
]
blocks_group.append(blocks_in_layout)
return blocks_group
def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang='en'):
"""
lines_group 进行行分段——layout内部进行分段。lines_group内每个元素是一个Layoutbox内的所有行。
1. 先计算每个group的左右边界。
......@@ -336,17 +360,20 @@ def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
if len(blocks) == 0:
list_info.append([False, False])
continue
if blocks[0]["type"] != BlockType.Text and blocks[-1]["type"] != BlockType.Text:
if blocks[0]['type'] != BlockType.Text and blocks[-1]['type'] != BlockType.Text:
list_info.append([False, False])
continue
if blocks[0]["type"] != BlockType.Text:
if blocks[0]['type'] != BlockType.Text:
is_start_list = False
if blocks[-1]["type"] != BlockType.Text:
if blocks[-1]['type'] != BlockType.Text:
is_end_list = False
lines = [line for block in blocks if
block["type"] == BlockType.Text for line in
block['lines']]
lines = [
line
for block in blocks
if block['type'] == BlockType.Text
for line in block['lines']
]
total_lines = len(lines)
if total_lines == 1 or total_lines == 0:
list_info.append([False, False])
......@@ -359,7 +386,9 @@ def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
2. 左对齐的列表块(其特点是左侧顶格的行数小于等于非顶格的行数,非定格首字母会有小写,顶格90%是大写。并且左侧顶格行数大于1,大于1是为了这种模式连续出现才能称之为列表)
这样的文本块,顶格的为一个段落开头,紧随其后非顶格的行属于这个段落。
"""
text_segments, list_start_line = __detect_list_lines(lines, new_layout_bbox, lang)
text_segments, list_start_line = __detect_list_lines(
lines, new_layout_bbox, lang
)
"""根据list_range,把lines分成几个部分
"""
......@@ -368,10 +397,17 @@ def __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang="en"):
for i in range(0, len(list_start)):
index = list_start[i] - 1
if index >= 0:
if "content" in lines[index]["spans"][-1] and lines[index]["spans"][-1].get('type', '') not in [
ContentType.InlineEquation, ContentType.InterlineEquation]:
lines[index]["spans"][-1]["content"] += '\n\n'
layout_list_info = [False, False] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
if 'content' in lines[index]['spans'][-1] and lines[index][
'spans'
][-1].get('type', '') not in [
ContentType.InlineEquation,
ContentType.InterlineEquation,
]:
lines[index]['spans'][-1]['content'] += '\n\n'
layout_list_info = [
False,
False,
] # 这个layout最后是不是列表,记录每一个layout里是不是列表开头,列表结尾
for content_type, start, end in text_segments:
if content_type == 'list':
if start == 0 and is_start_list is None:
......@@ -388,8 +424,7 @@ def __split_para_lines(lines: list, text_blocks: list) -> list:
other_paras = []
text_lines = []
for line in lines:
spans_types = [span["type"] for span in line]
spans_types = [span['type'] for span in line]
if ContentType.Table in spans_types:
other_paras.append([line])
continue
......@@ -402,20 +437,22 @@ def __split_para_lines(lines: list, text_blocks: list) -> list:
text_lines.append(line)
for block in text_blocks:
block_bbox = block["bbox"]
block_bbox = block['bbox']
para = []
for line in text_lines:
bbox = line["bbox"]
bbox = line['bbox']
if is_in_layout(bbox, block_bbox):
para.append(line)
if len(para) > 0:
text_paras.append(para)
paras = other_paras.extend(text_paras)
paras_sorted = sorted(paras, key=lambda x: x[0]["bbox"][1])
paras_sorted = sorted(paras, key=lambda x: x[0]['bbox'][1])
return paras_sorted
def __connect_list_inter_layout(blocks_group, new_layout_bbox, layout_list_info, page_num, lang):
def __connect_list_inter_layout(
blocks_group, new_layout_bbox, layout_list_info, page_num, lang
):
global debug_able
"""
如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。
......@@ -429,74 +466,108 @@ def __connect_list_inter_layout(blocks_group, new_layout_bbox, layout_list_info,
continue
pre_layout_list_info = layout_list_info[i - 1]
next_layout_list_info = layout_list_info[i]
pre_last_para = blocks_group[i - 1][-1].get("lines", [])
pre_last_para = blocks_group[i - 1][-1].get('lines', [])
next_paras = blocks_group[i]
next_first_para = next_paras[0]
if pre_layout_list_info[1] and not next_layout_list_info[0] and next_first_para[
"type"] == BlockType.Text: # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
if (
pre_layout_list_info[1]
and not next_layout_list_info[0]
and next_first_para['type'] == BlockType.Text
): # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
if debug_able:
logger.info(f"连接page {page_num} 内的list")
logger.info(f'连接page {page_num} 内的list')
# 向layout_paras[i] 寻找开头具有相同缩进的连续的行
may_list_lines = []
lines = next_first_para.get("lines", [])
lines = next_first_para.get('lines', [])
for line in lines:
if line['bbox'][0] > __find_layout_bbox_by_line(line['bbox'], new_layout_bbox)[0]:
if (
line['bbox'][0]
> __find_layout_bbox_by_line(line['bbox'], new_layout_bbox)[0]
):
may_list_lines.append(line)
else:
break
# 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
if len(may_list_lines) > 0 and len(set([x['bbox'][0] for x in may_list_lines])) == 1:
if (
len(may_list_lines) > 0
and len(set([x['bbox'][0] for x in may_list_lines])) == 1
):
pre_last_para.extend(may_list_lines)
next_first_para["lines"] = next_first_para["lines"][len(may_list_lines):]
return blocks_group, [layout_list_info[0][0], layout_list_info[-1][1]] # 同时还返回了这个页面级别的开头、结尾是不是列表的信息
def __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox,
pre_page_list_info, next_page_list_info, page_num, lang):
"""
如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO 因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。
"""
if len(pre_page_paras) == 0 or len(next_page_paras) == 0: # 0的时候最后的return 会出错
next_first_para['lines'] = next_first_para['lines'][
len(may_list_lines) :
]
return blocks_group, [
layout_list_info[0][0],
layout_list_info[-1][1],
] # 同时还返回了这个页面级别的开头、结尾是不是列表的信息
def __connect_list_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
pre_page_list_info,
next_page_list_info,
page_num,
lang,
):
"""如果上个layout的最后一个段落是列表,下一个layout的第一个段落也是列表,那么将他们连接起来。 TODO
因为没有区分列表和段落,所以这个方法暂时不实现。
根据layout_list_info判断是不是列表。,下个layout的第一个段如果不是列表,那么看他们是否有几行都有相同的缩进。"""
if (
len(pre_page_paras) == 0 or len(next_page_paras) == 0
): # 0的时候最后的return 会出错
return False
if len(pre_page_paras[-1]) == 0 or len(next_page_paras[0]) == 0:
return False
if pre_page_paras[-1][-1]["type"] != BlockType.Text or next_page_paras[0][0]["type"] != BlockType.Text:
if (
pre_page_paras[-1][-1]['type'] != BlockType.Text
or next_page_paras[0][0]['type'] != BlockType.Text
):
return False
if pre_page_list_info[1] and not next_page_list_info[0]: # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
if (
pre_page_list_info[1] and not next_page_list_info[0]
): # 前一个是列表结尾,后一个是非列表开头,此时检测是否有相同的缩进
if debug_able:
logger.info(f"连接page {page_num} 内的list")
logger.info(f'连接page {page_num} 内的list')
# 向layout_paras[i] 寻找开头具有相同缩进的连续的行
may_list_lines = []
next_page_first_para = next_page_paras[0][0]
if next_page_first_para["type"] == BlockType.Text:
lines = next_page_first_para["lines"]
if next_page_first_para['type'] == BlockType.Text:
lines = next_page_first_para['lines']
for line in lines:
if line['bbox'][0] > __find_layout_bbox_by_line(line['bbox'], next_page_layout_bbox)[0]:
if (
line['bbox'][0]
> __find_layout_bbox_by_line(line['bbox'], next_page_layout_bbox)[0]
):
may_list_lines.append(line)
else:
break
# 如果这些行的缩进是相等的,那么连到上一个layout的最后一个段落上。
if len(may_list_lines) > 0 and len(set([x['bbox'][0] for x in may_list_lines])) == 1:
if (
len(may_list_lines) > 0
and len(set([x['bbox'][0] for x in may_list_lines])) == 1
):
# pre_page_paras[-1].append(may_list_lines)
# 下一页合并到上一页最后一段,打一个cross_page的标签
for line in may_list_lines:
for span in line["spans"]:
span[CROSS_PAGE] = True
pre_page_paras[-1][-1]["lines"].extend(may_list_lines)
next_page_first_para["lines"] = next_page_first_para["lines"][len(may_list_lines):]
for span in line['spans']:
span[CROSS_PAGE] = True # noqa: F405
pre_page_paras[-1][-1]['lines'].extend(may_list_lines)
next_page_first_para['lines'] = next_page_first_para['lines'][
len(may_list_lines) :
]
return True
return False
def __find_layout_bbox_by_line(line_bbox, layout_bboxes):
"""
根据line找到所在的layout
"""
"""根据line找到所在的layout."""
for layout in layout_bboxes:
if is_in_layout(line_bbox, layout):
return layout
......@@ -525,39 +596,59 @@ def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox):
connected_layout_blocks.append(blocks_group[i])
continue
# text类型的段才需要考虑layout间的合并
if blocks_group[i - 1][-1]["type"] != BlockType.Text or blocks_group[i][0]["type"] != BlockType.Text:
if (
blocks_group[i - 1][-1]['type'] != BlockType.Text
or blocks_group[i][0]['type'] != BlockType.Text
):
connected_layout_blocks.append(blocks_group[i])
continue
if len(blocks_group[i - 1][-1]["lines"]) == 0 or len(blocks_group[i][0]["lines"]) == 0:
if (
len(blocks_group[i - 1][-1]['lines']) == 0
or len(blocks_group[i][0]['lines']) == 0
):
connected_layout_blocks.append(blocks_group[i])
continue
pre_last_line = blocks_group[i - 1][-1]["lines"][-1]
next_first_line = blocks_group[i][0]["lines"][0]
except Exception as e:
logger.error(f"page layout {i} has no line")
pre_last_line = blocks_group[i - 1][-1]['lines'][-1]
next_first_line = blocks_group[i][0]['lines'][0]
except Exception:
logger.error(f'page layout {i} has no line')
continue
pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
pre_last_line_text = ''.join(
[__get_span_text(span) for span in pre_last_line['spans']]
)
pre_last_line_type = pre_last_line['spans'][-1]['type']
next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
next_first_line_text = ''.join(
[__get_span_text(span) for span in next_first_line['spans']]
)
next_first_line_type = next_first_line['spans'][0]['type']
if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
if pre_last_line_type not in [
TEXT,
INLINE_EQUATION,
] or next_first_line_type not in [TEXT, INLINE_EQUATION]:
connected_layout_blocks.append(blocks_group[i])
continue
pre_layout = __find_layout_bbox_by_line(pre_last_line['bbox'], new_layout_bbox)
next_layout = __find_layout_bbox_by_line(next_first_line['bbox'], new_layout_bbox)
next_layout = __find_layout_bbox_by_line(
next_first_line['bbox'], new_layout_bbox
)
pre_x2_max = pre_layout[2] if pre_layout else -1
next_x0_min = next_layout[0] if next_layout else -1
pre_last_line_text = pre_last_line_text.strip()
next_first_line_text = next_first_line_text.strip()
if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text and pre_last_line_text[
-1] not in LINE_STOP_FLAG and \
next_first_line['bbox'][0] == next_x0_min: # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
if (
pre_last_line['bbox'][2] == pre_x2_max
and pre_last_line_text
and pre_last_line_text[-1] not in LINE_STOP_FLAG
and next_first_line['bbox'][0] == next_x0_min
): # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
"""连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
connected_layout_blocks[-1][-1]["lines"].extend(blocks_group[i][0]["lines"])
blocks_group[i][0]["lines"] = [] # 删除后一个layout第一个段落中的lines,因为他已经被合并到前一个layout的最后一个段落了
blocks_group[i][0][LINES_DELETED] = True
connected_layout_blocks[-1][-1]['lines'].extend(blocks_group[i][0]['lines'])
blocks_group[i][0][
'lines'
] = [] # 删除后一个layout第一个段落中的lines,因为他已经被合并到前一个layout的最后一个段落了
blocks_group[i][0][LINES_DELETED] = True # noqa: F405
# if len(layout_paras[i]) == 0:
# layout_paras.pop(i)
# else:
......@@ -569,8 +660,14 @@ def __connect_para_inter_layoutbox(blocks_group, new_layout_bbox):
return connected_layout_blocks
def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox, next_page_layout_bbox, page_num,
lang):
def __connect_para_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
page_num,
lang,
):
"""
连接起来相邻两个页面的段落——前一个页面最后一个段落和后一个页面的第一个段落。
是否可以连接的条件:
......@@ -578,33 +675,53 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
2. 后一个页面的第一个段落第一行没有空白开头。
"""
# 有的页面可能压根没有文字
if len(pre_page_paras) == 0 or len(next_page_paras) == 0 or len(pre_page_paras[0]) == 0 or len(
next_page_paras[0]) == 0: # TODO [[]]为什么出现在pre_page_paras里?
if (
len(pre_page_paras) == 0
or len(next_page_paras) == 0
or len(pre_page_paras[0]) == 0
or len(next_page_paras[0]) == 0
): # TODO [[]]为什么出现在pre_page_paras里?
return False
pre_last_block = pre_page_paras[-1][-1]
next_first_block = next_page_paras[0][0]
if pre_last_block["type"] != BlockType.Text or next_first_block["type"] != BlockType.Text:
if (
pre_last_block['type'] != BlockType.Text
or next_first_block['type'] != BlockType.Text
):
return False
if len(pre_last_block["lines"]) == 0 or len(next_first_block["lines"]) == 0:
if len(pre_last_block['lines']) == 0 or len(next_first_block['lines']) == 0:
return False
pre_last_para = pre_last_block["lines"]
next_first_para = next_first_block["lines"]
pre_last_para = pre_last_block['lines']
next_first_para = next_first_block['lines']
pre_last_line = pre_last_para[-1]
next_first_line = next_first_para[0]
pre_last_line_text = ''.join([__get_span_text(span) for span in pre_last_line['spans']])
pre_last_line_text = ''.join(
[__get_span_text(span) for span in pre_last_line['spans']]
)
pre_last_line_type = pre_last_line['spans'][-1]['type']
next_first_line_text = ''.join([__get_span_text(span) for span in next_first_line['spans']])
next_first_line_text = ''.join(
[__get_span_text(span) for span in next_first_line['spans']]
)
next_first_line_type = next_first_line['spans'][0]['type']
if pre_last_line_type not in [TEXT, INLINE_EQUATION] or next_first_line_type not in [TEXT,
INLINE_EQUATION]: # TODO,真的要做好,要考虑跨table, image, 行间的情况
if pre_last_line_type not in [
TEXT,
INLINE_EQUATION,
] or next_first_line_type not in [
TEXT,
INLINE_EQUATION,
]: # TODO,真的要做好,要考虑跨table, image, 行间的情况
# 不是文本,不连接
return False
pre_x2_max_bbox = __find_layout_bbox_by_line(pre_last_line['bbox'], pre_page_layout_bbox)
pre_x2_max_bbox = __find_layout_bbox_by_line(
pre_last_line['bbox'], pre_page_layout_bbox
)
if not pre_x2_max_bbox:
return False
next_x0_min_bbox = __find_layout_bbox_by_line(next_first_line['bbox'], next_page_layout_bbox)
next_x0_min_bbox = __find_layout_bbox_by_line(
next_first_line['bbox'], next_page_layout_bbox
)
if not next_x0_min_bbox:
return False
......@@ -613,18 +730,21 @@ def __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_b
pre_last_line_text = pre_last_line_text.strip()
next_first_line_text = next_first_line_text.strip()
if pre_last_line['bbox'][2] == pre_x2_max and pre_last_line_text[-1] not in LINE_STOP_FLAG and \
next_first_line['bbox'][0] == next_x0_min: # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
if (
pre_last_line['bbox'][2] == pre_x2_max
and pre_last_line_text[-1] not in LINE_STOP_FLAG
and next_first_line['bbox'][0] == next_x0_min
): # 前面一行沾满了整个行,并且没有结尾符号.下一行没有空白开头。
"""连接段落条件成立,将前一个layout的段落和后一个layout的段落连接。"""
# 下一页合并到上一页最后一段,打一个cross_page的标签
for line in next_first_para:
for span in line["spans"]:
span[CROSS_PAGE] = True
for span in line['spans']:
span[CROSS_PAGE] = True # noqa: F405
pre_last_para.extend(next_first_para)
# next_page_paras[0].pop(0) # 删除后一个页面的第一个段落, 因为他已经被合并到前一个页面的最后一个段落了。
next_page_paras[0][0]["lines"] = []
next_page_paras[0][0][LINES_DELETED] = True
next_page_paras[0][0]['lines'] = []
next_page_paras[0][0][LINES_DELETED] = True # noqa: F405
return True
else:
return False
......@@ -667,38 +787,73 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang):
single_line_paras_tag = []
for i in range(len(layout_para)):
# single_line_paras_tag.append(len(layout_para[i]) == 1 and layout_para[i][0]['spans'][0]['type'] == TEXT)
single_line_paras_tag.append(layout_para[i]['type'] == BlockType.Text and len(layout_para[i]["lines"]) == 1)
single_line_paras_tag.append(
layout_para[i]['type'] == BlockType.Text
and len(layout_para[i]['lines']) == 1
)
"""找出来连续的单行文本,如果连续行高度相同,那么合并为一个段落。"""
consecutive_single_line_indices = find_consecutive_true_regions(single_line_paras_tag)
consecutive_single_line_indices = find_consecutive_true_regions(
single_line_paras_tag
)
if len(consecutive_single_line_indices) > 0:
"""检查这些行是否是高度相同的,居中的"""
"""检查这些行是否是高度相同的,居中的."""
for start, end in consecutive_single_line_indices:
# start += index_offset
# end += index_offset
line_hi = np.array([block["lines"][0]['bbox'][3] - block["lines"][0]['bbox'][1] for block in
layout_para[start:end + 1]])
first_line_text = ''.join([__get_span_text(span) for span in layout_para[start]["lines"][0]['spans']])
if "Table" in first_line_text or "Figure" in first_line_text:
line_hi = np.array(
[
block['lines'][0]['bbox'][3] - block['lines'][0]['bbox'][1]
for block in layout_para[start : end + 1]
]
)
first_line_text = ''.join(
[
__get_span_text(span)
for span in layout_para[start]['lines'][0]['spans']
]
)
if 'Table' in first_line_text or 'Figure' in first_line_text:
pass
if debug_able:
logger.info(line_hi.std())
if line_hi.std() < 2:
"""行高度相同,那么判断是否居中"""
all_left_x0 = [block["lines"][0]['bbox'][0] for block in layout_para[start:end + 1]]
all_right_x1 = [block["lines"][0]['bbox'][2] for block in layout_para[start:end + 1]]
"""行高度相同,那么判断是否居中."""
all_left_x0 = [
block['lines'][0]['bbox'][0]
for block in layout_para[start : end + 1]
]
all_right_x1 = [
block['lines'][0]['bbox'][2]
for block in layout_para[start : end + 1]
]
layout_center = (layout_box[0] + layout_box[2]) / 2
if all([x0 < layout_center < x1 for x0, x1 in zip(all_left_x0, all_right_x1)]) \
and not all([x0 == layout_box[0] for x0 in all_left_x0]) \
and not all([x1 == layout_box[2] for x1 in all_right_x1]):
merge_para = [block["lines"][0] for block in layout_para[start:end + 1]]
para_text = ''.join([__get_span_text(span) for line in merge_para for span in line['spans']])
if (
all(
[
x0 < layout_center < x1
for x0, x1 in zip(all_left_x0, all_right_x1)
]
)
and not all([x0 == layout_box[0] for x0 in all_left_x0])
and not all([x1 == layout_box[2] for x1 in all_right_x1])
):
merge_para = [
block['lines'][0] for block in layout_para[start : end + 1]
]
para_text = ''.join(
[
__get_span_text(span)
for line in merge_para
for span in line['spans']
]
)
if debug_able:
logger.info(para_text)
layout_para[start]["lines"] = merge_para
layout_para[start]['lines'] = merge_para
for i_para in range(start + 1, end + 1):
layout_para[i_para]["lines"] = []
layout_para[i_para][LINES_DELETED] = True
layout_para[i_para]['lines'] = []
layout_para[i_para][LINES_DELETED] = True # noqa: F405
# layout_para[start:end + 1] = [merge_para]
# index_offset -= end - start
......@@ -707,18 +862,13 @@ def __connect_middle_align_text(page_paras, new_layout_bbox, page_num, lang):
def __merge_signle_list_text(page_paras, new_layout_bbox, page_num, lang):
"""
找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。
"""
"""找出来连续的单行文本,如果首行顶格,接下来的几个单行段落缩进对齐,那么合并为一个段落。"""
pass
def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
"""
根据line和layout情况进行分段
先实现一个根据行末尾特征分段的简单方法。
"""
"""根据line和layout情况进行分段 先实现一个根据行末尾特征分段的简单方法。"""
"""
算法思路:
1. 扫描layout里每一行,找出来行尾距离layout有边界有一定距离的行。
......@@ -727,15 +877,20 @@ def __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang):
4. 图、表,目前独占一行,不考虑分段。
"""
blocks_group = __group_line_by_layout(blocks, layout_bboxes) # block内分段
layout_list_info = __split_para_in_layoutbox(blocks_group, new_layout_bbox, lang) # layout内分段
blocks_group, page_list_info = __connect_list_inter_layout(blocks_group, new_layout_bbox, layout_list_info,
page_num, lang) # layout之间连接列表段落
connected_layout_blocks = __connect_para_inter_layoutbox(blocks_group, new_layout_bbox) # layout间链接段落
layout_list_info = __split_para_in_layoutbox(
blocks_group, new_layout_bbox, lang
) # layout内分段
blocks_group, page_list_info = __connect_list_inter_layout(
blocks_group, new_layout_bbox, layout_list_info, page_num, lang
) # layout之间连接列表段落
connected_layout_blocks = __connect_para_inter_layoutbox(
blocks_group, new_layout_bbox
) # layout间链接段落
return connected_layout_blocks, page_list_info
def para_split(pdf_info_dict, debug_mode, lang="en"):
def para_split(pdf_info_dict, debug_mode, lang='en'):
global debug_able
debug_able = debug_mode
new_layout_of_pages = [] # 数组的数组,每个元素是一个页面的layoutS
......@@ -745,7 +900,9 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
layout_bboxes = page['layout_bboxes']
new_layout_bbox = __common_pre_proc(blocks, layout_bboxes)
new_layout_of_pages.append(new_layout_bbox)
splited_blocks, page_list_info = __do_split_page(blocks, layout_bboxes, new_layout_bbox, page_num, lang)
splited_blocks, page_list_info = __do_split_page(
blocks, layout_bboxes, new_layout_bbox, page_num, lang
)
all_page_list_info.append(page_list_info)
page['para_blocks'] = splited_blocks
......@@ -759,18 +916,31 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
pre_page_layout_bbox = new_layout_of_pages[page_num - 1]
next_page_layout_bbox = new_layout_of_pages[page_num]
is_conn = __connect_para_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
next_page_layout_bbox, page_num, lang)
is_conn = __connect_para_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
page_num,
lang,
)
if debug_able:
if is_conn:
logger.info(f"连接了第{page_num - 1}页和第{page_num}页的段落")
is_list_conn = __connect_list_inter_page(pre_page_paras, next_page_paras, pre_page_layout_bbox,
next_page_layout_bbox, all_page_list_info[page_num - 1],
all_page_list_info[page_num], page_num, lang)
logger.info(f'连接了第{page_num - 1}页和第{page_num}页的段落')
is_list_conn = __connect_list_inter_page(
pre_page_paras,
next_page_paras,
pre_page_layout_bbox,
next_page_layout_bbox,
all_page_list_info[page_num - 1],
all_page_list_info[page_num],
page_num,
lang,
)
if debug_able:
if is_list_conn:
logger.info(f"连接了第{page_num - 1}页和第{page_num}页的列表段落")
logger.info(f'连接了第{page_num - 1}页和第{page_num}页的列表段落')
"""接下来可能会漏掉一些特别的一些可以合并的内容,对他们进行段落连接
1. 正文中有时出现一个行顶格,接下来几行缩进的情况。
......@@ -786,4 +956,4 @@ def para_split(pdf_info_dict, debug_mode, lang="en"):
for page_num, page in enumerate(pdf_info_dict.values()):
page_paras = page['para_blocks']
page_blocks = [block for layout in page_paras for block in layout]
page["para_blocks"] = page_blocks
page['para_blocks'] = page_blocks
import copy
from loguru import logger
from magic_pdf.libs.Constants import LINES_DELETED, CROSS_PAGE
from magic_pdf.libs.ocr_content_type import BlockType, ContentType
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';')
from magic_pdf.config.constants import CROSS_PAGE, LINES_DELETED
from magic_pdf.config.ocr_content_type import BlockType, ContentType
LINE_STOP_FLAG = (
'.',
'!',
'?',
'。',
'!',
'?',
')',
')',
'"',
'”',
':',
':',
';',
';',
)
LIST_END_FLAG = ('.', '。', ';', ';')
class ListLineTag:
IS_LIST_START_LINE = "is_list_start_line"
IS_LIST_END_LINE = "is_list_end_line"
IS_LIST_START_LINE = 'is_list_start_line'
IS_LIST_END_LINE = 'is_list_end_line'
def __process_blocks(blocks):
......@@ -27,12 +40,14 @@ def __process_blocks(blocks):
# 如果当前块是 text 类型
if current_block['type'] == 'text':
current_block["bbox_fs"] = copy.deepcopy(current_block["bbox"])
if 'lines' in current_block and len(current_block["lines"]) > 0:
current_block['bbox_fs'] = [min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']])]
current_block['bbox_fs'] = copy.deepcopy(current_block['bbox'])
if 'lines' in current_block and len(current_block['lines']) > 0:
current_block['bbox_fs'] = [
min([line['bbox'][0] for line in current_block['lines']]),
min([line['bbox'][1] for line in current_block['lines']]),
max([line['bbox'][2] for line in current_block['lines']]),
max([line['bbox'][3] for line in current_block['lines']]),
]
current_group.append(current_block)
# 检查下一个块是否存在
......@@ -64,6 +79,7 @@ def __is_list_or_index_block(block):
line_height = first_line['bbox'][3] - first_line['bbox'][1]
block_weight = block['bbox_fs'][2] - block['bbox_fs'][0]
block_height = block['bbox_fs'][3] - block['bbox_fs'][1]
page_weight, page_height = block['page_size']
left_close_num = 0
left_not_close_num = 0
......@@ -75,10 +91,17 @@ def __is_list_or_index_block(block):
multiple_para_flag = False
last_line = block['lines'][-1]
if page_weight == 0:
block_weight_radio = 0
else:
block_weight_radio = block_weight / page_weight
# logger.info(f"block_weight_radio: {block_weight_radio}")
# 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
if (first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2 and
abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2 and
block['bbox_fs'][2] - last_line['bbox'][2] > line_height
if (
first_line['bbox'][0] - block['bbox_fs'][0] > line_height / 2
and abs(last_line['bbox'][0] - block['bbox_fs'][0]) < line_height / 2
and block['bbox_fs'][2] - last_line['bbox'][2] > line_height
):
multiple_para_flag = True
......@@ -86,14 +109,14 @@ def __is_list_or_index_block(block):
line_mid_x = (line['bbox'][0] + line['bbox'][2]) / 2
block_mid_x = (block['bbox_fs'][0] + block['bbox_fs'][2]) / 2
if (
line['bbox'][0] - block['bbox_fs'][0] > 0.8 * line_height and
block['bbox_fs'][2] - line['bbox'][2] > 0.8 * line_height
line['bbox'][0] - block['bbox_fs'][0] > 0.8 * line_height
and block['bbox_fs'][2] - line['bbox'][2] > 0.8 * line_height
):
external_sides_not_close_num += 1
if abs(line_mid_x - block_mid_x) < line_height / 2:
center_close_num += 1
line_text = ""
line_text = ''
for span in line['spans']:
span_type = span['type']
......@@ -114,7 +137,12 @@ def __is_list_or_index_block(block):
right_close_num += 1
else:
# 右侧不顶格情况下是否有一段距离,拍脑袋用0.3block宽度做阈值
closed_area = 0.26 * block_weight
# block宽的阈值可以小些,block窄的阈值要大
if block_weight_radio >= 0.5:
closed_area = 0.26 * block_weight
else:
closed_area = 0.36 * block_weight
if block['bbox_fs'][2] - line['bbox'][2] > closed_area:
right_not_close_num += 1
......@@ -136,15 +164,19 @@ def __is_list_or_index_block(block):
if line_text[-1].isdigit():
num_end_count += 1
if num_start_count / len(lines_text_list) >= 0.8 or num_end_count / len(lines_text_list) >= 0.8:
if (
num_start_count / len(lines_text_list) >= 0.8
or num_end_count / len(lines_text_list) >= 0.8
):
line_num_flag = True
if flag_end_count / len(lines_text_list) >= 0.8:
line_end_flag = True
# 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
if ((left_close_num / len(block['lines']) >= 0.8 or right_close_num / len(block['lines']) >= 0.8)
and line_num_flag
):
if (
left_close_num / len(block['lines']) >= 0.8
or right_close_num / len(block['lines']) >= 0.8
) and line_num_flag:
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.Index
......@@ -152,17 +184,21 @@ def __is_list_or_index_block(block):
# 全部line都居中的特殊list识别,每行都需要换行,特征是多行,且大多数行都前后not_close,每line中点x坐标接近
# 补充条件block的长宽比有要求
elif (
external_sides_not_close_num >= 2 and
center_close_num == len(block['lines']) and
external_sides_not_close_num / len(block['lines']) >= 0.5 and
block_height / block_weight > 0.4
external_sides_not_close_num >= 2
and center_close_num == len(block['lines'])
and external_sides_not_close_num / len(block['lines']) >= 0.5
and block_height / block_weight > 0.4
):
for line in block['lines']:
line[ListLineTag.IS_LIST_START_LINE] = True
return BlockType.List
elif left_close_num >= 2 and (
right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2) and not multiple_para_flag:
elif (
left_close_num >= 2
and (right_not_close_num >= 2 or line_end_flag or left_not_close_num >= 2)
and not multiple_para_flag
# and block_weight_radio > 0.27
):
# 处理一种特殊的没有缩进的list,所有行都贴左边,通过右边的空隙判断是否是item尾
if left_close_num / len(block['lines']) > 0.8:
# 这种是每个item只有一行,且左边都贴边的短item list
......@@ -173,10 +209,15 @@ def __is_list_or_index_block(block):
# 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
elif line_end_flag:
for i, line in enumerate(block['lines']):
if len(lines_text_list[i]) > 0 and lines_text_list[i][-1] in LIST_END_FLAG:
if (
len(lines_text_list[i]) > 0
and lines_text_list[i][-1] in LIST_END_FLAG
):
line[ListLineTag.IS_LIST_END_LINE] = True
if i + 1 < len(block['lines']):
block['lines'][i + 1][ListLineTag.IS_LIST_START_LINE] = True
block['lines'][i + 1][
ListLineTag.IS_LIST_START_LINE
] = True
# line item基本没有结束标识符,而且也没有缩进,按右侧空隙判断哪些是item end
else:
line_start_flag = False
......@@ -185,7 +226,10 @@ def __is_list_or_index_block(block):
line[ListLineTag.IS_LIST_START_LINE] = True
line_start_flag = False
if abs(block['bbox_fs'][2] - line['bbox'][2]) > 0.1 * block_weight:
if (
abs(block['bbox_fs'][2] - line['bbox'][2])
> 0.1 * block_weight
):
line[ListLineTag.IS_LIST_END_LINE] = True
line_start_flag = True
# 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_FLAG 结尾且数量和start line 一致
......@@ -223,18 +267,25 @@ def __merge_2_text_blocks(block1, block2):
if len(last_line['spans']) > 0:
last_span = last_line['spans'][-1]
line_height = last_line['bbox'][3] - last_line['bbox'][1]
if (abs(block2['bbox_fs'][2] - last_line['bbox'][2]) < line_height and
not last_span['content'].endswith(LINE_STOP_FLAG) and
# 两个block宽度差距超过2倍也不合并
abs(block1_weight - block2_weight) < min_block_weight
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
if len(first_line['spans']) > 0:
first_span = first_line['spans'][0]
if len(first_span['content']) > 0:
span_start_with_num = first_span['content'][0].isdigit()
if (
abs(block2['bbox_fs'][2] - last_line['bbox'][2])
< line_height
and not last_span['content'].endswith(LINE_STOP_FLAG)
# 两个block宽度差距超过2倍也不合并
and abs(block1_weight - block2_weight) < min_block_weight
and not span_start_with_num
):
if block1['page_num'] != block2['page_num']:
for line in block1['lines']:
for span in line['spans']:
span[CROSS_PAGE] = True
block2['lines'].extend(block1['lines'])
block1['lines'] = []
block1[LINES_DELETED] = True
return block1, block2
......@@ -263,7 +314,6 @@ def __is_list_group(text_blocks_group):
def __para_merge_page(blocks):
page_text_blocks_groups = __process_blocks(blocks)
for text_blocks_group in page_text_blocks_groups:
if len(text_blocks_group) > 0:
# 需要先在合并前对所有block判断是否为list or index block
for block in text_blocks_group:
......@@ -272,7 +322,6 @@ def __para_merge_page(blocks):
# logger.info(f"{block['type']}:{block}")
if len(text_blocks_group) > 1:
# 在合并前判断这个group 是否是一个 list group
is_list_group = __is_list_group(text_blocks_group)
......@@ -284,11 +333,18 @@ def __para_merge_page(blocks):
if i - 1 >= 0:
prev_block = text_blocks_group[i - 1]
if current_block['type'] == 'text' and prev_block['type'] == 'text' and not is_list_group:
if (
current_block['type'] == 'text'
and prev_block['type'] == 'text'
and not is_list_group
):
__merge_2_text_blocks(current_block, prev_block)
elif (
(current_block['type'] == BlockType.List and prev_block['type'] == BlockType.List) or
(current_block['type'] == BlockType.Index and prev_block['type'] == BlockType.Index)
current_block['type'] == BlockType.List
and prev_block['type'] == BlockType.List
) or (
current_block['type'] == BlockType.Index
and prev_block['type'] == BlockType.Index
):
__merge_2_list_blocks(current_block, prev_block)
......@@ -296,12 +352,13 @@ def __para_merge_page(blocks):
continue
def para_split(pdf_info_dict, debug_mode=False):
def para_split(pdf_info_dict):
all_blocks = []
for page_num, page in pdf_info_dict.items():
blocks = copy.deepcopy(page['preproc_blocks'])
for block in blocks:
block['page_num'] = page_num
block['page_size'] = page['page_size']
all_blocks.extend(blocks)
__para_merge_page(all_blocks)
......@@ -317,4 +374,4 @@ if __name__ == '__main__':
# 调用函数
groups = __process_blocks(input_blocks)
for group_index, group in enumerate(groups):
print(f"Group {group_index}: {group}")
print(f'Group {group_index}: {group}')
......@@ -9,6 +9,7 @@ def parse_pdf_by_ocr(pdf_bytes,
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
......@@ -18,4 +19,5 @@ def parse_pdf_by_ocr(pdf_bytes,
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
......@@ -10,6 +10,7 @@ def parse_pdf_by_txt(
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
dataset = PymuDocDataset(pdf_bytes)
return pdf_parse_union(dataset,
......@@ -19,4 +20,5 @@ def parse_pdf_by_txt(
start_page_id=start_page_id,
end_page_id=end_page_id,
debug_mode=debug_mode,
lang=lang,
)
......@@ -2,38 +2,47 @@ import time
from loguru import logger
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.ocr_content_type import ContentType
from magic_pdf.layout.layout_sort import (LAYOUT_UNPROC, get_bboxes_layout,
get_columns_cnt_of_layout)
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.layout.layout_sort import get_bboxes_layout, LAYOUT_UNPROC, get_columns_cnt_of_layout
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal
from magic_pdf.libs.ocr_content_type import ContentType
from magic_pdf.model.magic_model import MagicModel
from magic_pdf.para.para_split_v2 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.construct_page_dict import \
ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
from magic_pdf.pre_proc.equations_replace import remove_chars_in_text_blocks, replace_equations_in_textblock, \
combine_chars_to_pymudict
from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layout_split
from magic_pdf.pre_proc.ocr_dict_merge import sort_blocks_by_layout, fill_spans_in_blocks, fix_block_spans, \
fix_discarded_block
from magic_pdf.pre_proc.ocr_span_list_modify import remove_overlaps_min_spans, get_qa_need_list_v2, \
remove_overlaps_low_confidence_spans
from magic_pdf.pre_proc.resolve_bbox_conflict import check_useful_block_horizontal_overlap
from magic_pdf.pre_proc.equations_replace import (
combine_chars_to_pymudict, remove_chars_in_text_blocks,
replace_equations_in_textblock)
from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
ocr_prepare_bboxes_for_layout_split
from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
fix_block_spans,
fix_discarded_block,
sort_blocks_by_layout)
from magic_pdf.pre_proc.ocr_span_list_modify import (
get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
remove_overlaps_min_spans)
from magic_pdf.pre_proc.resolve_bbox_conflict import \
check_useful_block_horizontal_overlap
def remove_horizontal_overlap_block_which_smaller(all_bboxes):
useful_blocks = []
for bbox in all_bboxes:
useful_blocks.append({
"bbox": bbox[:4]
})
is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = check_useful_block_horizontal_overlap(useful_blocks)
useful_blocks.append({'bbox': bbox[:4]})
is_useful_block_horz_overlap, smaller_bbox, bigger_bbox = (
check_useful_block_horizontal_overlap(useful_blocks)
)
if is_useful_block_horz_overlap:
logger.warning(
f"skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}")
f'skip this page, reason: {DropReason.USEFUL_BLOCK_HOR_OVERLAP}, smaller bbox is {smaller_bbox}, bigger bbox is {bigger_bbox}'
)
for bbox in all_bboxes.copy():
if smaller_bbox == bbox[:4]:
all_bboxes.remove(bbox)
......@@ -41,27 +50,27 @@ def remove_horizontal_overlap_block_which_smaller(all_bboxes):
return is_useful_block_horz_overlap, all_bboxes
def __replace_STX_ETX(text_str:str):
""" Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
def __replace_STX_ETX(text_str: str):
"""Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
Args:
text_str (str): raw text
Args:
text_str (str): raw text
Returns:
_type_: replaced text
Returns:
_type_: replaced text
"""
if text_str:
s = text_str.replace('\u0002', "'")
s = s.replace("\u0003", "'")
s = s.replace('\u0003', "'")
return s
return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_raw_blocks = pdf_page.get_text("dict", flags=fitz.TEXTFLAGS_TEXT)["blocks"]
char_level_text_blocks = pdf_page.get_text("rawdict", flags=fitz.TEXTFLAGS_TEXT)[
"blocks"
text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
'blocks'
]
text_blocks = combine_chars_to_pymudict(text_raw_blocks, char_level_text_blocks)
text_blocks = replace_equations_in_textblock(
......@@ -71,189 +80,254 @@ def txt_spans_extract(pdf_page, inline_equations, interline_equations):
text_blocks = remove_chars_in_text_blocks(text_blocks)
spans = []
for v in text_blocks:
for line in v["lines"]:
for span in line["spans"]:
bbox = span["bbox"]
for line in v['lines']:
for span in line['spans']:
bbox = span['bbox']
if float_equal(bbox[0], bbox[2]) or float_equal(bbox[1], bbox[3]):
continue
if span.get('type') not in (ContentType.InlineEquation, ContentType.InterlineEquation):
if span.get('type') not in (
ContentType.InlineEquation,
ContentType.InterlineEquation,
):
spans.append(
{
"bbox": list(span["bbox"]),
"content": __replace_STX_ETX(span["text"]),
"type": ContentType.Text,
"score": 1.0,
'bbox': list(span['bbox']),
'content': __replace_STX_ETX(span['text']),
'type': ContentType.Text,
'score': 1.0,
}
)
return spans
def replace_text_span(pymu_spans, ocr_spans):
return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
return list(filter(lambda x: x['type'] != ContentType.Text, ocr_spans)) + pymu_spans
def parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode):
def parse_page_core(
pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
):
need_drop = False
drop_reason = []
'''从magic_model对象中获取后面会用到的区块信息'''
"""从magic_model对象中获取后面会用到的区块信息"""
img_blocks = magic_model.get_imgs(page_id)
table_blocks = magic_model.get_tables(page_id)
discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_id)
)
page_w, page_h = magic_model.get_page_size(page_id)
spans = magic_model.get_all_spans(page_id)
'''根据parse_mode,构造spans'''
if parse_mode == "txt":
"""根据parse_mode,构造spans"""
if parse_mode == 'txt':
"""ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans = txt_spans_extract(
pdf_docs[page_id], inline_equations, interline_equations
)
spans = replace_text_span(pymu_spans, spans)
elif parse_mode == "ocr":
elif parse_mode == 'ocr':
pass
else:
raise Exception("parse_mode must be txt or ocr")
raise Exception('parse_mode must be txt or ocr')
'''删除重叠spans中置信度较低的那些'''
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
'''删除重叠spans中较小的那些'''
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
'''对image和table截图'''
spans = ocr_cut_image_and_table(spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter)
"""对image和table截图"""
spans = ocr_cut_image_and_table(
spans, pdf_docs[page_id], page_id, pdf_bytes_md5, imageWriter
)
'''将所有区块的bbox整理到一起'''
"""将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = []
if len(interline_equation_blocks) > 0:
all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
interline_equation_blocks, page_w, page_h)
all_bboxes, all_discarded_blocks, drop_reasons = (
ocr_prepare_bboxes_for_layout_split(
img_blocks,
table_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equation_blocks,
page_w,
page_h,
)
)
else:
all_bboxes, all_discarded_blocks, drop_reasons = ocr_prepare_bboxes_for_layout_split(
img_blocks, table_blocks, discarded_blocks, text_blocks, title_blocks,
interline_equations, page_w, page_h)
all_bboxes, all_discarded_blocks, drop_reasons = (
ocr_prepare_bboxes_for_layout_split(
img_blocks,
table_blocks,
discarded_blocks,
text_blocks,
title_blocks,
interline_equations,
page_w,
page_h,
)
)
if len(drop_reasons) > 0:
need_drop = True
drop_reason.append(DropReason.OVERLAP_BLOCKS_CAN_NOT_SEPARATION)
'''先处理不需要排版的discarded_blocks'''
discarded_block_with_spans, spans = fill_spans_in_blocks(all_discarded_blocks, spans, 0.4)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
'''如果当前页面没有bbox则跳过'''
"""如果当前页面没有bbox则跳过"""
if len(all_bboxes) == 0:
logger.warning(f"skip this page, not found useful bbox, page_id: {page_id}")
return ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
[], [], interline_equations, fix_discarded_blocks,
need_drop, drop_reason)
logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
return ocr_construct_page_component_v2(
[],
[],
page_id,
page_w,
page_h,
[],
[],
[],
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
"""在切分之前,先检查一下bbox是否有左右重叠的情况,如果有,那么就认为这个pdf暂时没有能力处理好,这种左右重叠的情况大概率是由于pdf里的行间公式、表格没有被正确识别出来造成的 """
while True: # 循环检查左右重叠的情况,如果存在就删除掉较小的那个bbox,直到不存在左右重叠的情况
is_useful_block_horz_overlap, all_bboxes = remove_horizontal_overlap_block_which_smaller(all_bboxes)
is_useful_block_horz_overlap, all_bboxes = (
remove_horizontal_overlap_block_which_smaller(all_bboxes)
)
if is_useful_block_horz_overlap:
need_drop = True
drop_reason.append(DropReason.USEFUL_BLOCK_HOR_OVERLAP)
else:
break
'''根据区块信息计算layout'''
"""根据区块信息计算layout"""
page_boundry = [0, 0, page_w, page_h]
layout_bboxes, layout_tree = get_bboxes_layout(all_bboxes, page_boundry, page_id)
if len(text_blocks) > 0 and len(all_bboxes) > 0 and len(layout_bboxes) == 0:
logger.warning(
f"skip this page, page_id: {page_id}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}")
f'skip this page, page_id: {page_id}, reason: {DropReason.CAN_NOT_DETECT_PAGE_LAYOUT}'
)
need_drop = True
drop_reason.append(DropReason.CAN_NOT_DETECT_PAGE_LAYOUT)
"""以下去掉复杂的布局和超过2列的布局"""
if any([lay["layout_label"] == LAYOUT_UNPROC for lay in layout_bboxes]): # 复杂的布局
if any(
[lay['layout_label'] == LAYOUT_UNPROC for lay in layout_bboxes]
): # 复杂的布局
logger.warning(
f"skip this page, page_id: {page_id}, reason: {DropReason.COMPLICATED_LAYOUT}")
f'skip this page, page_id: {page_id}, reason: {DropReason.COMPLICATED_LAYOUT}'
)
need_drop = True
drop_reason.append(DropReason.COMPLICATED_LAYOUT)
layout_column_width = get_columns_cnt_of_layout(layout_tree)
if layout_column_width > 2: # 去掉超过2列的布局pdf
logger.warning(
f"skip this page, page_id: {page_id}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}")
f'skip this page, page_id: {page_id}, reason: {DropReason.TOO_MANY_LAYOUT_COLUMNS}'
)
need_drop = True
drop_reason.append(DropReason.TOO_MANY_LAYOUT_COLUMNS)
'''根据layout顺序,对当前页面所有需要留下的block进行排序'''
"""根据layout顺序,对当前页面所有需要留下的block进行排序"""
sorted_blocks = sort_blocks_by_layout(all_bboxes, layout_bboxes)
'''将span填入排好序的blocks中'''
"""将span填入排好序的blocks中"""
block_with_spans, spans = fill_spans_in_blocks(sorted_blocks, spans, 0.3)
'''对block进行fix操作'''
"""对block进行fix操作"""
fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
'''获取QA需要外置的list'''
"""获取QA需要外置的list"""
images, tables, interline_equations = get_qa_need_list_v2(fix_blocks)
'''构造pdf_info_dict'''
page_info = ocr_construct_page_component_v2(fix_blocks, layout_bboxes, page_id, page_w, page_h, layout_tree,
images, tables, interline_equations, fix_discarded_blocks,
need_drop, drop_reason)
"""构造pdf_info_dict"""
page_info = ocr_construct_page_component_v2(
fix_blocks,
layout_bboxes,
page_id,
page_w,
page_h,
layout_tree,
images,
tables,
interline_equations,
fix_discarded_blocks,
need_drop,
drop_reason,
)
return page_info
def pdf_parse_union(pdf_bytes,
model_list,
imageWriter,
parse_mode,
start_page_id=0,
end_page_id=None,
debug_mode=False,
):
def pdf_parse_union(
pdf_bytes,
model_list,
imageWriter,
parse_mode,
start_page_id=0,
end_page_id=None,
debug_mode=False,
):
pdf_bytes_md5 = compute_md5(pdf_bytes)
pdf_docs = fitz.open("pdf", pdf_bytes)
pdf_docs = fitz.open('pdf', pdf_bytes)
'''初始化空的pdf_info_dict'''
"""初始化空的pdf_info_dict"""
pdf_info_dict = {}
'''用model_list和docs对象初始化magic_model'''
"""用model_list和docs对象初始化magic_model"""
magic_model = MagicModel(model_list, pdf_docs)
'''根据输入的起始范围解析pdf'''
"""根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(pdf_docs) - 1
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else len(pdf_docs) - 1
)
if end_page_id > len(pdf_docs) - 1:
logger.warning("end_page_id is out of range, use pdf_docs length")
logger.warning('end_page_id is out of range, use pdf_docs length')
end_page_id = len(pdf_docs) - 1
'''初始化启动时间'''
"""初始化启动时间"""
start_time = time.time()
for page_id, page in enumerate(pdf_docs):
'''debug时输出每页解析的耗时'''
"""debug时输出每页解析的耗时."""
if debug_mode:
time_now = time.time()
logger.info(
f"page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}"
f'page_id: {page_id}, last_page_cost_time: {get_delta_time(start_time)}'
)
start_time = time_now
'''解析pdf中的每一页'''
"""解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
page_info = parse_page_core(
pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
)
else:
page_w = page.rect.width
page_h = page.rect.height
page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
[], [], [], [],
True, "skip page")
pdf_info_dict[f"page_{page_id}"] = page_info
page_info = ocr_construct_page_component_v2(
[], [], page_id, page_w, page_h, [], [], [], [], [], True, 'skip page'
)
pdf_info_dict[f'page_{page_id}'] = page_info
"""分段"""
para_split(pdf_info_dict, debug_mode=debug_mode)
......@@ -261,7 +335,7 @@ def pdf_parse_union(pdf_bytes,
"""dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
new_pdf_info_dict = {
"pdf_info": pdf_info_list,
'pdf_info': pdf_info_list,
}
return new_pdf_info_dict
......
......@@ -7,18 +7,32 @@ from typing import List
import torch
from loguru import logger
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.commons import fitz, get_delta_time
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
from magic_pdf.libs.convert_utils import dict_to_list
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.hash_utils import compute_md5
from magic_pdf.libs.local_math import float_equal
from magic_pdf.libs.ocr_content_type import ContentType, BlockType
from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
from magic_pdf.model.magic_model import MagicModel
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import torchtext
if torchtext.__version__ >= "0.18.0":
torchtext.disable_torchtext_deprecation_warning()
except ImportError:
pass
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
from magic_pdf.pre_proc.construct_page_dict import \
......@@ -30,8 +44,8 @@ from magic_pdf.pre_proc.equations_replace import (
from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
ocr_prepare_bboxes_for_layout_split_v2
from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
fix_discarded_block,
fix_block_spans_v2)
fix_block_spans_v2,
fix_discarded_block)
from magic_pdf.pre_proc.ocr_span_list_modify import (
get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
remove_overlaps_min_spans)
......@@ -74,7 +88,151 @@ def __replace_STX_ETX(text_str: str):
return text_str
def txt_spans_extract(pdf_page, inline_equations, interline_equations):
def chars_to_content(span):
# # 先给chars按char['bbox']的x坐标排序
# span['chars'] = sorted(span['chars'], key=lambda x: x['bbox'][0])
# 先给chars按char['bbox']的中心点的x坐标排序
span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
content = ''
# 求char的平均宽度
if len(span['chars']) == 0:
span['content'] = content
del span['chars']
return
else:
char_width_sum = sum([char['bbox'][2] - char['bbox'][0] for char in span['chars']])
char_avg_width = char_width_sum / len(span['chars'])
for char in span['chars']:
# 如果下一个char的x0和上一个char的x1距离超过一个字符宽度,则需要在中间插入一个空格
if char['bbox'][0] - span['chars'][span['chars'].index(char) - 1]['bbox'][2] > char_avg_width:
content += ' '
content += char['c']
span['content'] = __replace_STX_ETX(content)
del span['chars']
LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
def fill_char_in_spans(spans, all_chars):
for char in all_chars:
for span in spans:
# 判断char是否属于LINE_STOP_FLAG
if char['c'] in LINE_STOP_FLAG:
char_is_line_stop_flag = True
else:
char_is_line_stop_flag = False
if calculate_char_in_span(char['bbox'], span['bbox'], char_is_line_stop_flag):
span['chars'].append(char)
break
for span in spans:
chars_to_content(span)
# 使用鲁棒性更强的中心点坐标判断
def calculate_char_in_span(char_bbox, span_bbox, char_is_line_stop_flag):
char_center_x = (char_bbox[0] + char_bbox[2]) / 2
char_center_y = (char_bbox[1] + char_bbox[3]) / 2
span_center_y = (span_bbox[1] + span_bbox[3]) / 2
span_height = span_bbox[3] - span_bbox[1]
if (
span_bbox[0] < char_center_x < span_bbox[2]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height / 4 # 字符的中轴和span的中轴高度差不能超过1/4span高度
):
return True
else:
# 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
# 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
if char_is_line_stop_flag:
if (
(span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
and char_center_x > span_bbox[0]
and span_bbox[1] < char_center_y < span_bbox[3]
and abs(char_center_y - span_center_y) < span_height / 4
):
return True
else:
return False
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
useful_spans = []
unuseful_spans = []
for span in spans:
for block in all_bboxes:
if block[7] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
continue
else:
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
useful_spans.append(span)
break
for block in all_discarded_blocks:
if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
unuseful_spans.append(span)
break
text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
# @todo: 拿到char之后把倾斜角度较大的先删一遍
all_pymu_chars = []
for block in text_blocks:
for line in block['lines']:
for span in line['spans']:
all_pymu_chars.extend(span['chars'])
new_spans = []
for span in useful_spans:
if span['type'] in [ContentType.Text]:
span['chars'] = []
new_spans.append(span)
for span in unuseful_spans:
if span['type'] in [ContentType.Text]:
span['chars'] = []
new_spans.append(span)
fill_char_in_spans(new_spans, all_pymu_chars)
empty_spans = []
for span in new_spans:
if len(span['content']) == 0:
empty_spans.append(span)
if len(empty_spans) > 0:
# 初始化ocr模型
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name="ocr",
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
for span in empty_spans:
spans.remove(span)
# 对span的bbox截图
span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode="cv2")
ocr_res = ocr_model.ocr(span_img, det=False)
# logger.info(f"ocr_res: {ocr_res}")
# logger.info(f"empty_span: {span}")
if ocr_res and len(ocr_res) > 0:
if len(ocr_res[0]) > 0:
ocr_text, ocr_score = ocr_res[0][0]
if ocr_score > 0.5 and len(ocr_text) > 0:
span['content'] = ocr_text
spans.append(span)
return spans
def txt_spans_extract_v1(pdf_page, inline_equations, interline_equations):
text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
'blocks'
......@@ -164,8 +322,8 @@ class ModelSingleton:
def do_predict(boxes: List[List[int]], model) -> List[int]:
from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (boxes2inputs, parse_logits,
prepare_inputs)
from magic_pdf.model.sub_modules.reading_oreder.layoutreader.helpers import (
boxes2inputs, parse_logits, prepare_inputs)
inputs = boxes2inputs(boxes)
inputs = prepare_inputs(inputs, model)
......@@ -206,7 +364,9 @@ def cal_block_index(fix_blocks, sorted_bboxes):
del block['real_lines']
import numpy as np
from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import recursive_xy_cut
from magic_pdf.model.sub_modules.reading_oreder.layoutreader.xycut import \
recursive_xy_cut
random_boxes = np.array(block_bboxes)
np.random.shuffle(random_boxes)
......@@ -291,7 +451,7 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list.append(bbox)
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
bbox = block['bbox']
block["real_lines"] = copy.deepcopy(block['lines'])
block['real_lines'] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
block['lines'] = []
for line in lines:
......@@ -462,18 +622,16 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
):
need_drop = False
drop_reason = []
"""从magic_model对象中获取后面会用到的区块信息"""
# img_blocks = magic_model.get_imgs(page_id)
# table_blocks = magic_model.get_tables(page_id)
img_groups = magic_model.get_imgs_v2(page_id)
table_groups = magic_model.get_tables_v2(page_id)
"""对image和table的区块分组"""
img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
)
......@@ -517,38 +675,20 @@ def parse_page_core(
page_h,
)
"""获取所有的spans信息"""
spans = magic_model.get_all_spans(page_id)
"""根据parse_mode,构造spans"""
if parse_mode == SupportedPdfParseMethod.TXT:
"""ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
spans = replace_text_span(pymu_spans, spans)
elif parse_mode == SupportedPdfParseMethod.OCR:
pass
else:
raise Exception('parse_mode must be txt or ocr')
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""对image和table截图"""
spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans, spans = fill_spans_in_blocks(
all_discarded_blocks, spans, 0.4
)
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
"""如果当前页面没有bbox则跳过"""
"""如果当前页面没有有效的bbox则跳过"""
if len(all_bboxes) == 0:
logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
return ocr_construct_page_component_v2(
......@@ -566,7 +706,32 @@ def parse_page_core(
drop_reason,
)
"""将span填入blocks中"""
"""删除重叠spans中置信度较低的那些"""
spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
"""删除重叠spans中较小的那些"""
spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
"""根据parse_mode,构造spans,主要是文本类的字符填充"""
if parse_mode == SupportedPdfParseMethod.TXT:
"""之前的公式替换方案"""
# pymu_spans = txt_spans_extract_v1(page_doc, inline_equations, interline_equations)
# spans = replace_text_span(pymu_spans, spans)
"""ocr 中文本类的 span 用 pymu spans 替换!"""
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
elif parse_mode == SupportedPdfParseMethod.OCR:
pass
else:
raise Exception('parse_mode must be txt or ocr')
"""对image和table截图"""
spans = ocr_cut_image_and_table(
spans, page_doc, page_id, pdf_bytes_md5, imageWriter
)
"""span填充进block"""
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
"""对block进行fix操作"""
......@@ -616,6 +781,7 @@ def pdf_parse_union(
start_page_id=0,
end_page_id=None,
debug_mode=False,
lang=None,
):
pdf_bytes_md5 = compute_md5(dataset.data_bits())
......@@ -652,7 +818,7 @@ def pdf_parse_union(
"""解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
)
else:
page_info = page.get_page_info()
......@@ -664,7 +830,7 @@ def pdf_parse_union(
pdf_info_dict[f'page_{page_id}'] = page_info
"""分段"""
para_split(pdf_info_dict, debug_mode=debug_mode)
para_split(pdf_info_dict)
"""dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
......
from abc import ABC, abstractmethod
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.dict2md.ocr_mkcontent import union_make
from magic_pdf.filter.pdf_classify_by_type import classify
from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
from magic_pdf.libs.MakeContentConfig import MakeMode, DropMode
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.libs.drop_reason import DropReason
from magic_pdf.libs.json_compressor import JsonCompressor
class AbsPipe(ABC):
"""
txt和ocr处理的抽象类
"""
PIP_OCR = "ocr"
PIP_TXT = "txt"
"""txt和ocr处理的抽象类."""
PIP_OCR = 'ocr'
PIP_TXT = 'txt'
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
self.pdf_bytes = pdf_bytes
self.model_list = model_list
......@@ -29,29 +27,23 @@ class AbsPipe(ABC):
self.layout_model = layout_model
self.formula_enable = formula_enable
self.table_enable = table_enable
def get_compress_pdf_mid_data(self):
return JsonCompressor.compress_json(self.pdf_mid_data)
@abstractmethod
def pipe_classify(self):
"""
有状态的分类
"""
"""有状态的分类."""
raise NotImplementedError
@abstractmethod
def pipe_analyze(self):
"""
有状态的跑模型分析
"""
"""有状态的跑模型分析."""
raise NotImplementedError
@abstractmethod
def pipe_parse(self):
"""
有状态的解析
"""
"""有状态的解析."""
raise NotImplementedError
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
......@@ -64,27 +56,25 @@ class AbsPipe(ABC):
@staticmethod
def classify(pdf_bytes: bytes) -> str:
"""
根据pdf的元数据,判断是文本pdf,还是ocr pdf
"""
"""根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
pdf_meta = pdf_meta_scan(pdf_bytes)
if pdf_meta.get("_need_drop", False): # 如果返回了需要丢弃的标志,则抛出异常
if pdf_meta.get('_need_drop', False): # 如果返回了需要丢弃的标志,则抛出异常
raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
else:
is_encrypted = pdf_meta["is_encrypted"]
is_needs_password = pdf_meta["is_needs_password"]
is_encrypted = pdf_meta['is_encrypted']
is_needs_password = pdf_meta['is_needs_password']
if is_encrypted or is_needs_password: # 加密的,需要密码的,没有页面的,都不处理
raise Exception(f"pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}")
raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
else:
is_text_pdf, results = classify(
pdf_meta["total_page"],
pdf_meta["page_width_pts"],
pdf_meta["page_height_pts"],
pdf_meta["image_info_per_page"],
pdf_meta["text_len_per_page"],
pdf_meta["imgs_per_page"],
pdf_meta["text_layout_per_page"],
pdf_meta["invalid_chars"],
pdf_meta['total_page'],
pdf_meta['page_width_pts'],
pdf_meta['page_height_pts'],
pdf_meta['image_info_per_page'],
pdf_meta['text_len_per_page'],
pdf_meta['imgs_per_page'],
pdf_meta['text_layout_per_page'],
pdf_meta['invalid_chars'],
)
if is_text_pdf:
return AbsPipe.PIP_TXT
......@@ -93,22 +83,16 @@ class AbsPipe(ABC):
@staticmethod
def mk_uni_format(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF) -> list:
"""
根据pdf类型,生成统一格式content_list
"""
"""根据pdf类型,生成统一格式content_list."""
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data["pdf_info"]
pdf_info_list = pdf_mid_data['pdf_info']
content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
return content_list
@staticmethod
def mk_markdown(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD) -> list:
"""
根据pdf类型,markdown
"""
"""根据pdf类型,markdown."""
pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
pdf_info_list = pdf_mid_data["pdf_info"]
pdf_info_list = pdf_mid_data['pdf_info']
md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
return md_content
from loguru import logger
from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
from magic_pdf.config.make_content_config import DropMode, MakeMode
from magic_pdf.data.data_reader_writer import DataWriter
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.pipe.AbsPipe import AbsPipe
from magic_pdf.user_api import parse_ocr_pdf
class OCRPipe(AbsPipe):
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
start_page_id=0, end_page_id=None, lang=None,
layout_model=None, formula_enable=None, table_enable=None):
super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
......@@ -32,10 +32,10 @@ class OCRPipe(AbsPipe):
def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
logger.info("ocr_pipe mk content list finished")
logger.info('ocr_pipe mk content list finished')
return result
def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
logger.info(f"ocr_pipe mk {md_make_mode} finished")
logger.info(f'ocr_pipe mk {md_make_mode} finished')
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment