Unverified Commit 19f72c23 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1614 from opendatalab/release-1.1.0

Release 1.1.0
parents 4d70b16b adcace44
......@@ -14,7 +14,7 @@
[![Downloads](https://static.pepy.tech/badge/magic-pdf)](https://pepy.tech/project/magic-pdf)
[![Downloads](https://static.pepy.tech/badge/magic-pdf/month)](https://pepy.tech/project/magic-pdf)
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
[![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
[![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb)
......@@ -33,7 +33,7 @@
<a href="https://github.com/opendatalab/PDF-Extract-Kit">PDF-Extract-Kit: High-Quality PDF Extraction Toolkit</a>🔥🔥🔥
<br>
<br>
<a href="https://mineru.org.cn/client?source=github">
<a href="https://mineru.net/client?source=github">
Easier to use: Just grab MinerU Desktop. No coding, no login, just a simple interface and smooth interactions. Enjoy it without any fuss!</a>🚀🚀🚀
</p>
......@@ -41,12 +41,20 @@ Easier to use: Just grab MinerU Desktop. No coding, no login, just a simple inte
<!-- join us -->
<p align="center">
👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://cdn.vansin.top/internlm/mineru.jpg" target="_blank">WeChat</a>
👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="http://mineru.space/s/V85Yl" target="_blank">WeChat</a>
</p>
</div>
# Changelog
- 2025/01/22 1.1.0 released. In this version we have focused on improving parsing accuracy and efficiency:
- Model capability upgrade (requires re-executing the [model download process](docs/how_to_download_models_en.md) to obtain incremental updates of model files)
- The layout recognition model has been upgraded to the latest `doclayout_yolo(2501)` model, improving layout recognition accuracy.
- The formula parsing model has been upgraded to the latest `unimernet(2501)` model, improving formula recognition accuracy.
- Performance optimization
- On devices that meet certain configuration requirements (16GB+ VRAM), by optimizing resource usage and restructuring the processing pipeline, overall parsing speed has been increased by more than 50%.
- Parsing effect optimization
- Added a new heading classification feature (testing version, enabled by default) to the online demo([mineru.net](https://mineru.net/OpenSourceTools/Extractor)/[huggingface](https://huggingface.co/spaces/opendatalab/MinerU)/[modelscope](https://www.modelscope.cn/studios/OpenDataLab/MinerU)), which supports hierarchical classification of headings, thereby enhancing document structuring.
- 2025/01/10 1.0.1 released. This is our first official release, where we have introduced a completely new API interface and enhanced compatibility through extensive refactoring, as well as a brand new automatic language identification feature:
- New API Interface
- For the data-side API, we have introduced the Dataset class, designed to provide a robust and flexible data processing framework. This framework currently supports a variety of document formats, including images (.jpg and .png), PDFs, Word documents (.doc and .docx), and PowerPoint presentations (.ppt and .pptx). It ensures effective support for data processing tasks ranging from simple to complex.
......@@ -226,7 +234,7 @@ There are three different ways to experience MinerU:
### Online Demo
Stable Version (Stable version verified by QA):
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
Test Version (Synced with dev branch updates, testing new features):
[![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
......@@ -273,6 +281,7 @@ You can modify certain configurations in this file to enable or disable features
},
"table-config": {
"model": "rapid_table", // Default to using "rapid_table", can be switched to "tablemaster" or "struct_eqtable".
"sub_model": "slanet_plus", // When the model is "rapid_table", you can choose a sub_model. The options are "slanet_plus" and "unitable"
"enable": true, // The table recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
"max_time": 400
}
......@@ -356,6 +365,7 @@ TODO
- [x] Reading order based on the model
- [x] Recognition of `index` and `list` in the main text
- [x] Table recognition
- [x] Heading Classification
- [ ] Code block recognition in the main text
- [ ] [Chemical formula recognition](docs/chemical_knowledge_introduction/introduction.pdf)
- [ ] Geometric shape recognition
......@@ -365,7 +375,6 @@ TODO
- Reading order is determined by the model based on the spatial distribution of readable content, and may be out of order in some areas under extremely complex layouts.
- Vertical text is not supported.
- Tables of contents and lists are recognized through rules, and some uncommon list formats may not be recognized.
- Only one level of headings is supported; hierarchical headings are not currently supported.
- Code blocks are not yet supported in the layout model.
- Comic books, art albums, primary school textbooks, and exercises cannot be parsed well.
- Table recognition may result in row/column recognition errors in complex tables.
......
......@@ -14,7 +14,7 @@
[![Downloads](https://static.pepy.tech/badge/magic-pdf)](https://pepy.tech/project/magic-pdf)
[![Downloads](https://static.pepy.tech/badge/magic-pdf/month)](https://pepy.tech/project/magic-pdf)
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
[![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
[![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb)
......@@ -33,19 +33,27 @@
<a href="https://github.com/opendatalab/PDF-Extract-Kit">PDF-Extract-Kit: 高质量PDF解析工具箱</a>🔥🔥🔥
<br>
<br>
<a href="https://mineru.org.cn/client?source=github">更便捷的使用方式:MinerU桌面端。无需编程,无需登录,图形界面,简单交互,畅用无忧。</a>🚀🚀🚀
<a href="https://mineru.net/client?source=github">更便捷的使用方式:MinerU桌面端。无需编程,无需登录,图形界面,简单交互,畅用无忧。</a>🚀🚀🚀
</p>
<!-- join us -->
<p align="center">
👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://cdn.vansin.top/internlm/mineru.jpg" target="_blank">WeChat</a>
👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="http://mineru.space/s/V85Yl" target="_blank">WeChat</a>
</p>
</div>
# 更新记录
- 2025/01/22 1.1.0 发布,在这个版本我们重点提升了解析的精度与效率:
- 模型能力升级(需重新执行[模型下载流程](docs/how_to_download_models_zh_cn.md)以获得模型文件的增量更新)
- 布局识别模型升级到最新的`doclayout_yolo(2501)`模型,提升了layout识别精度
- 公式解析模型升级到最新的`unimernet(2501)`模型,提升了公式识别精度
- 性能优化
- 在配置满足一定条件(显存16GB+)的设备上,通过优化资源占用和重构处理流水线,整体解析速度提升50%以上
- 解析效果优化
- 在线demo([mineru.net](https://mineru.net/OpenSourceTools/Extractor)/[huggingface](https://huggingface.co/spaces/opendatalab/MinerU)/[modelscope](https://www.modelscope.cn/studios/OpenDataLab/MinerU))上新增标题分级功能(测试版本,默认开启),支持对标题进行分级,提升文档结构化程度
- 2025/01/10 1.0.1 发布,这是我们的第一个正式版本,在这个版本中,我们通过大量重构带来了全新的API接口和更广泛的兼容性,以及全新的自动语言识别功能:
- 全新API接口
- 对于数据侧API,我们引入了Dataset类,旨在提供一个强大而灵活的数据处理框架。该框架当前支持包括图像(.jpg及.png)、PDF、Word(.doc及.docx)、以及PowerPoint(.ppt及.pptx)在内的多种文档格式,确保了从简单到复杂的数据处理任务都能得到有效的支持。
......@@ -227,7 +235,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
### 在线体验
稳定版(经过QA验证的稳定版本):
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github)
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
测试版(同步dev分支更新,测试新特性):
......@@ -277,6 +285,7 @@ pip install -U "magic-pdf[full]" --extra-index-url https://wheels.myhloli.com -i
},
"table-config": {
"model": "rapid_table", // 默认使用"rapid_table",可以切换为"tablemaster"和"struct_eqtable"
"sub_model": "slanet_plus", // 当model为"rapid_table"时,可以自选sub_model,可选项为"slanet_plus"和"unitable"
"enable": true, // 表格识别功能默认是开启的,如果需要关闭请修改此处的值为"false"
"max_time": 400
}
......@@ -359,6 +368,7 @@ TODO
- [x] 基于模型的阅读顺序
- [x] 正文中目录、列表识别
- [x] 表格识别
- [x] 标题分级
- [ ] 正文中代码块识别
- [ ] [化学式识别](docs/chemical_knowledge_introduction/introduction.pdf)
- [ ] 几何图形识别
......@@ -368,7 +378,6 @@ TODO
- 阅读顺序基于模型对可阅读内容在空间中的分布进行排序,在极端复杂的排版下可能会部分区域乱序
- 不支持竖排文字
- 目录和列表通过规则进行识别,少部分不常见的列表形式可能无法识别
- 标题只有一级,目前不支持标题分级
- 代码块在layout模型里还没有支持
- 漫画书、艺术图册、小学教材、习题尚不能很好解析
- 表格识别在复杂表格上可能会出现行/列识别错误
......
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
PyMuPDF>=1.24.9
PyMuPDF>=1.24.9,<=1.24.14
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
......@@ -17,10 +17,9 @@ paddlepaddle==3.0.0b1
struct-eqtable==0.3.2
einops
accelerate
doclayout_yolo==0.0.2
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
doclayout-yolo==0.0.2
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
PyMuPDF>=1.24.9
PyMuPDF>=1.24.9,<=1.24.14
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
......@@ -16,10 +16,9 @@ paddleocr==2.7.3
struct-eqtable==0.3.2
einops
accelerate
doclayout_yolo==0.0.2
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
doclayout-yolo==0.0.2
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
boto3>=1.28.43
Brotli>=1.1.0
click>=8.1.7
PyMuPDF>=1.24.9
PyMuPDF>=1.24.9,<=1.24.14
loguru>=0.6.0
numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0
......@@ -16,10 +16,9 @@ paddleocr==2.7.3
struct-eqtable==0.3.2
einops
accelerate
doclayout_yolo==0.0.2
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
doclayout-yolo==0.0.2
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
......@@ -16,6 +16,7 @@
},
"table-config": {
"model": "rapid_table",
"sub_model": "slanet_plus",
"enable": true,
"max_time": 400
},
......@@ -39,5 +40,5 @@
"enable": false
}
},
"config_version": "1.1.0"
"config_version": "1.1.1"
}
\ No newline at end of file
......@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
if any([bbox1_area == 0, bbox2_area == 0]):
return 0
# Compute the intersection over union by taking the intersection area
# and dividing it by the sum of both areas minus the intersection area
iou = intersection_area / float(bbox1_area + bbox2_area -
intersection_area)
iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
return iou
......
......@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
for page in pdf_info:
page_line_list = []
for block in page['preproc_blocks']:
if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
if block['type'] in [BlockType.Text]:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in [BlockType.Image, BlockType.Table]:
elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
if 'virtual_lines' in block:
if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
for line in block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
......
......@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"):
from fast_langdetect import detect_language
def remove_invalid_surrogates(text):
# 移除无效的 UTF-16 代理对
return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
def detect_lang(text: str) -> str:
if len(text) == 0:
return ""
text = text.replace("\n", "")
text = remove_invalid_surrogates(text)
# print(text)
try:
lang_upper = detect_language(text)
except:
......@@ -37,3 +45,4 @@ if __name__ == '__main__':
print(detect_lang("<html>This is a test</html>"))
print(detect_lang("这个是中文测试。"))
print(detect_lang("<html>这个是中文测试。</html>"))
print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
\ No newline at end of file
......@@ -7,19 +7,19 @@ from loguru import logger
from PIL import Image
from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
# from magic_pdf.data.dataset import Dataset
# from magic_pdf.libs.clean_memory import clean_memory
# from magic_pdf.libs.config_reader import get_device
# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.operators.models import InferenceResult
# from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 4
YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16
......@@ -44,19 +44,20 @@ class BatchAnalyze:
modified_images = []
for image_index, image in enumerate(images):
pil_img = Image.fromarray(image)
width, height = pil_img.size
if height > width:
input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
)
layout_images.append(new_image)
modified_images.append([image_index, useful_list])
else:
layout_images.append(pil_img)
# width, height = pil_img.size
# if height > width:
# input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
# new_image, useful_list = crop_img(
# input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
# )
# layout_images.append(new_image)
# modified_images.append([image_index, useful_list])
# else:
layout_images.append(pil_img)
images_layout_res += self.model.layout_model.batch_predict(
layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
# layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
)
for image_index, useful_list in modified_images:
......@@ -78,7 +79,8 @@ class BatchAnalyze:
# 公式检测
mfd_start_time = time.time()
images_mfd_res = self.model.mfd_model.batch_predict(
images, self.batch_ratio * MFD_BASE_BATCH_SIZE
# images, self.batch_ratio * MFD_BASE_BATCH_SIZE
images, MFD_BASE_BATCH_SIZE
)
logger.info(
f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
......@@ -91,10 +93,12 @@ class BatchAnalyze:
images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
)
mfr_count = 0
for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
logger.info(
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
)
# 清理显存
......@@ -159,7 +163,7 @@ class BatchAnalyze:
elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.model.table_model.img2html(new_image)
elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = (
html_code, table_cell_bboxes, logic_points, elapse = (
self.model.table_model.predict(new_image)
)
run_time = time.time() - single_table_start_time
......@@ -195,81 +199,81 @@ class BatchAnalyze:
return images_layout_res
def doc_batch_analyze(
dataset: Dataset,
ocr: bool = False,
show_log: bool = False,
start_page_id=0,
end_page_id=None,
lang=None,
layout_model=None,
formula_enable=None,
table_enable=None,
batch_ratio: int | None = None,
) -> InferenceResult:
"""Perform batch analysis on a document dataset.
Args:
dataset (Dataset): The dataset containing document pages to be analyzed.
ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
show_log (bool, optional): Flag to enable logging. Defaults to False.
start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
lang (str, optional): Language for OCR. Defaults to None.
layout_model (optional): Layout model to be used for analysis. Defaults to None.
formula_enable (optional): Flag to enable formula detection. Defaults to None.
table_enable (optional): Flag to enable table detection. Defaults to None.
batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
Raises:
CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
Returns:
InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
"""
if not torch.cuda.is_available():
raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
lang = None if lang == '' else lang
# TODO: auto detect batch size
batch_ratio = 1 if batch_ratio is None else batch_ratio
end_page_id = end_page_id if end_page_id else len(dataset)
model_manager = ModelSingleton()
custom_model: CustomPEKModel = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
)
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
model_json = []
# batch analyze
images = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
analyze_result = batch_model(images)
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
# TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time()
clean_memory(get_device())
logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
return InferenceResult(model_json, dataset)
# def doc_batch_analyze(
# dataset: Dataset,
# ocr: bool = False,
# show_log: bool = False,
# start_page_id=0,
# end_page_id=None,
# lang=None,
# layout_model=None,
# formula_enable=None,
# table_enable=None,
# batch_ratio: int | None = None,
# ) -> InferenceResult:
# """Perform batch analysis on a document dataset.
#
# Args:
# dataset (Dataset): The dataset containing document pages to be analyzed.
# ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
# show_log (bool, optional): Flag to enable logging. Defaults to False.
# start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
# end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
# lang (str, optional): Language for OCR. Defaults to None.
# layout_model (optional): Layout model to be used for analysis. Defaults to None.
# formula_enable (optional): Flag to enable formula detection. Defaults to None.
# table_enable (optional): Flag to enable table detection. Defaults to None.
# batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
# Raises:
# CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
# Returns:
# InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
# """
#
# if not torch.cuda.is_available():
# raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
# lang = None if lang == '' else lang
# # TODO: auto detect batch size
# batch_ratio = 1 if batch_ratio is None else batch_ratio
# end_page_id = end_page_id if end_page_id else len(dataset)
#
# model_manager = ModelSingleton()
# custom_model: CustomPEKModel = model_manager.get_model(
# ocr, show_log, lang, layout_model, formula_enable, table_enable
# )
# batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
# model_json = []
#
# # batch analyze
# images = []
# for index in range(len(dataset)):
# if start_page_id <= index <= end_page_id:
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# images.append(img_dict['img'])
# analyze_result = batch_model(images)
#
# for index in range(len(dataset)):
# page_data = dataset.get_page(index)
# img_dict = page_data.get_image()
# page_width = img_dict['width']
# page_height = img_dict['height']
# if start_page_id <= index <= end_page_id:
# result = analyze_result.pop(0)
# else:
# result = []
#
# page_info = {'page_no': index, 'height': page_height, 'width': page_width}
# page_dict = {'layout_dets': result, 'page_info': page_info}
# model_json.append(page_dict)
#
# # TODO: clean memory when gpu memory is not enough
# clean_memory_start_time = time.time()
# clean_memory(get_device())
# logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
# return InferenceResult(model_json, dataset)
......@@ -3,8 +3,12 @@ import time
# 关闭paddle的信号处理
import paddle
import torch
from loguru import logger
from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram
paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
......@@ -154,33 +158,88 @@ def doc_analyze(
table_enable=None,
) -> InferenceResult:
end_page_id = end_page_id if end_page_id else len(dataset) - 1
model_manager = ModelSingleton()
custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable
)
batch_analyze = False
device = get_device()
npu_support = False
if str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
npu_support = True
if torch.cuda.is_available() and device != 'cpu' or npu_support:
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
if gpu_memory is not None and gpu_memory >= 8:
if 8 <= gpu_memory < 10:
batch_ratio = 2
elif 10 <= gpu_memory <= 12:
batch_ratio = 4
elif 12 < gpu_memory <= 16:
batch_ratio = 8
elif 16 < gpu_memory <= 24:
batch_ratio = 16
else:
batch_ratio = 32
if batch_ratio >= 1:
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
batch_analyze = True
model_json = []
doc_analyze_start = time.time()
if end_page_id is None:
end_page_id = len(dataset)
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
if batch_analyze:
# batch analyze
images = []
for index in range(len(dataset)):
if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
images.append(img_dict['img'])
analyze_result = batch_model(images)
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
else:
# single analyze
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
gc_start = time.time()
clean_memory(get_device())
......
......@@ -69,6 +69,7 @@ class CustomPEKModel:
self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
self.table_sub_model_name = self.table_config.get('sub_model', None)
# ocr config
self.apply_ocr = ocr
......@@ -144,7 +145,7 @@ class CustomPEKModel:
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
)
),
device=self.device,
device='cpu' if str(self.device).startswith("mps") else self.device,
)
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model(
......@@ -174,6 +175,7 @@ class CustomPEKModel:
table_max_time=self.table_max_time,
device=self.device,
ocr_engine=self.ocr_model,
table_sub_model_name=self.table_sub_model_name
)
logger.info('DocAnalysis init done!')
......@@ -192,24 +194,24 @@ class CustomPEKModel:
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
if height > width:
input_res = {"poly":[0,0,width,0,width,height,0,height]}
new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
layout_res = self.layout_model.predict(new_image)
for res in layout_res:
p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
p1 = p1 - paste_x + xmin
p2 = p2 - paste_y + ymin
p3 = p3 - paste_x + xmin
p4 = p4 - paste_y + ymin
p5 = p5 - paste_x + xmin
p6 = p6 - paste_y + ymin
p7 = p7 - paste_x + xmin
p8 = p8 - paste_y + ymin
res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
else:
layout_res = self.layout_model.predict(image)
# if height > width:
# input_res = {"poly":[0,0,width,0,width,height,0,height]}
# new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
# paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# layout_res = self.layout_model.predict(new_image)
# for res in layout_res:
# p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
# p1 = p1 - paste_x + xmin
# p2 = p2 - paste_y + ymin
# p3 = p3 - paste_x + xmin
# p4 = p4 - paste_y + ymin
# p5 = p5 - paste_x + xmin
# p6 = p6 - paste_y + ymin
# p7 = p7 - paste_x + xmin
# p8 = p8 - paste_y + ymin
# res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
# else:
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f'layout detection time: {layout_cost}')
......@@ -228,7 +230,7 @@ class CustomPEKModel:
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
# 清理显存
clean_vram(self.device, vram_threshold=8)
clean_vram(self.device, vram_threshold=6)
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = (
......@@ -276,7 +278,7 @@ class CustomPEKModel:
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict(
html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
new_image
)
run_time = time.time() - single_table_start_time
......
......@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
def predict(self, image):
layout_res = []
doclayout_yolo_res = self.model.predict(
image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
image,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False, device=self.device
)[0]
for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(),
......@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size],
imgsz=1024,
conf=0.25,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False,
device=self.device,
......
......@@ -89,7 +89,7 @@ class UnimernetModel(object):
mf_image_list.append(bbox_img)
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
......
......@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
......@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel(ocr_engine)
table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else:
logger.error('table model type not allow')
exit(1)
......@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('ocr_engine')
kwargs.get('ocr_engine'),
kwargs.get('table_sub_model_name')
)
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
......
......@@ -7,6 +7,8 @@ import base64
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
import importlib.resources
from paddleocr import PaddleOCR
from ppocr.utils.utility import check_and_read
......@@ -327,30 +329,35 @@ class ONNXModelSingleton:
return self._models[key]
def onnx_model_init(key):
import importlib.resources
resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
onnx_model = None
additional_ocr_params = {
"use_onnx": True,
"det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
"rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
"cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
"det_db_box_thresh": key[1],
"use_dilation": key[2],
"det_db_unclip_ratio": key[3],
}
# logger.info(f"additional_ocr_params: {additional_ocr_params}")
if key[0] is not None:
additional_ocr_params["lang"] = key[0]
from paddleocr import PaddleOCR
onnx_model = PaddleOCR(**additional_ocr_params)
if onnx_model is None:
logger.error('model init failed')
if len(key) < 4:
logger.error('Invalid key length, expected at least 4 elements')
exit(1)
else:
return onnx_model
\ No newline at end of file
try:
with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
additional_ocr_params = {
"use_onnx": True,
"det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
"rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
"cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
"det_db_box_thresh": key[1],
"use_dilation": key[2],
"det_db_unclip_ratio": key[3],
}
if key[0] is not None:
additional_ocr_params["lang"] = key[0]
# logger.info(f"additional_ocr_params: {additional_ocr_params}")
onnx_model = PaddleOCR(**additional_ocr_params)
if onnx_model is None:
logger.error('model init failed')
exit(1)
else:
return onnx_model
except Exception as e:
logger.exception(f'Error initializing model: {e}')
exit(1)
\ No newline at end of file
......@@ -2,12 +2,27 @@ import cv2
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable
from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType
from magic_pdf.libs.config_reader import get_device
class RapidTableModel(object):
def __init__(self, ocr_engine):
self.table_model = RapidTable()
def __init__(self, ocr_engine, table_sub_model_name):
sub_model_list = [model.value for model in ModelType]
if table_sub_model_name is None:
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
else:
input_args = RapidTableInput(model_type=table_sub_model_name)
else:
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
self.table_model = RapidTable(input_args)
# if ocr_engine is None:
# self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available():
......@@ -45,7 +60,11 @@ class RapidTableModel(object):
ocr_result = None
if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
table_results = self.table_model(np.asarray(image), ocr_result)
html_code = table_results.pred_html
table_cell_bboxes = table_results.cell_bboxes
logic_points = table_results.logic_points
elapse = table_results.elapse
return html_code, table_cell_bboxes, logic_points, elapse
else:
return None, None, None
return None, None, None, None
import copy
import math
import os
import re
import statistics
......@@ -12,7 +13,7 @@ from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list
......@@ -117,9 +118,10 @@ def fill_char_in_spans(spans, all_chars):
for char in all_chars:
# 跳过非法bbox的char
x1, y1, x2, y2 = char['bbox']
if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
continue
# x1, y1, x2, y2 = char['bbox']
# if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
# continue
for span in spans:
if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
span['chars'].append(char)
......@@ -173,12 +175,35 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
return False
def remove_tilted_line(text_blocks):
for block in text_blocks:
remove_lines = []
for line in block['lines']:
cosine, sine = line['dir']
# 计算弧度值
angle_radians = math.atan2(sine, cosine)
# 将弧度值转换为角度值
angle_degrees = math.degrees(angle_radians)
if 2 < abs(angle_degrees) < 88:
remove_lines.append(line)
for line in remove_lines:
block['lines'].remove(line)
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
# cid用0xfffd表示,连字符拆开
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
# cid用0xfffd表示,连字符不拆开
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
#text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
# 自定义flags出现较多0xfffd,可能是pymupdf可以自行处理内置字典的pdf,不再使用
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
# text_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
# 移除所有角度不为0或90的line
remove_tilted_line(text_blocks_raw)
all_pymu_chars = []
for block in text_blocks_raw:
for line in block['lines']:
......@@ -365,10 +390,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
block['index'] = median_value
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
if 'real_lines' in block:
block['virtual_lines'] = copy.deepcopy(block['lines'])
block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
else:
# 使用xycut排序
block_bboxes = []
......@@ -417,7 +443,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox
if line_height * 3 < block_height:
if line_height * 2 < block_height:
if (
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点
......@@ -425,16 +451,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3
lines = 3
line_height = (y1 - y0) / lines
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
lines = int(block_height / line_height) + 1
else: # 判断长宽比
if block_height / block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2
lines = 2
line_height = (y1 - y0) / lines
# 确定从哪个y位置开始绘制线条
current_y = y0
......@@ -453,30 +479,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = []
def add_lines_to_block(b):
line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
b['lines'] = []
for line_bbox in line_bboxes:
b['lines'].append({'bbox': line_bbox, 'spans': []})
page_line_list.extend(line_bboxes)
for block in fix_blocks:
if block['type'] in [
BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote
]:
if len(block['lines']) == 0:
bbox = block['bbox']
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
add_lines_to_block(block)
elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
block['real_lines'] = copy.deepcopy(block['lines'])
add_lines_to_block(block)
else:
for line in block['lines']:
bbox = line['bbox']
page_line_list.append(bbox)
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
bbox = block['bbox']
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
block['real_lines'] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
block['lines'] = []
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
add_lines_to_block(block)
if len(page_line_list) > 200: # layoutreader最高支持512line
return None
......@@ -663,12 +691,77 @@ def parse_page_core(
discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = (
magic_model.get_equations(page_id)
)
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
page_w, page_h = magic_model.get_page_size(page_id)
def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
def merge_two_bbox(b1, b2):
x_min = min(b1['bbox'][0], b2['bbox'][0])
y_min = min(b1['bbox'][1], b2['bbox'][1])
x_max = max(b1['bbox'][2], b2['bbox'][2])
y_max = max(b1['bbox'][3], b2['bbox'][3])
return x_min, y_min, x_max, y_max
def merge_two_blocks(b1, b2):
# 合并两个标题块的边界框
b1['bbox'] = merge_two_bbox(b1, b2)
# 合并两个标题块的文本内容
line1 = b1['lines'][0]
line2 = b2['lines'][0]
line1['bbox'] = merge_two_bbox(line1, line2)
line1['spans'].extend(line2['spans'])
return b1, b2
# 按 y 轴重叠度聚集标题块
y_overlapping_blocks = []
title_bs = [b for b in blocks if b['type'] == BlockType.Title]
while title_bs:
block1 = title_bs.pop(0)
current_row = [block1]
to_remove = []
for block2 in title_bs:
if (
__is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9)
and len(block1['lines']) == 1
and len(block2['lines']) == 1
):
current_row.append(block2)
to_remove.append(block2)
for b in to_remove:
title_bs.remove(b)
y_overlapping_blocks.append(current_row)
# 按x轴坐标排序并合并标题块
to_remove_blocks = []
for row in y_overlapping_blocks:
if len(row) == 1:
continue
# 按x轴坐标排序
row.sort(key=lambda x: x['bbox'][0])
merged_block = row[0]
for i in range(1, len(row)):
left_block = merged_block
right_block = row[i]
left_height = left_block['bbox'][3] - left_block['bbox'][1]
right_height = right_block['bbox'][3] - right_block['bbox'][1]
if (
right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold
and left_height * 0.95 < right_height < left_height * 1.05
):
merged_block, to_remove_block = merge_two_blocks(merged_block, right_block)
to_remove_blocks.append(to_remove_block)
else:
merged_block = right_block
for b in to_remove_blocks:
blocks.remove(b)
"""将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = []
......@@ -753,6 +846,9 @@ def parse_page_core(
"""对block进行fix操作"""
fix_blocks = fix_block_spans_v2(block_with_spans)
"""同一行被断开的titile合并"""
merge_title_blocks(fix_blocks)
"""获取所有line并计算正文line的高度"""
line_height = get_line_height(fix_blocks)
......@@ -861,17 +957,23 @@ def pdf_parse_union(
formula_aided_config = llm_aided_config.get('formula_aided', None)
if formula_aided_config is not None:
if formula_aided_config.get('enable', False):
llm_aided_formula_start_time = time.time()
llm_aided_formula(pdf_info_dict, formula_aided_config)
logger.info(f'llm aided formula time: {round(time.time() - llm_aided_formula_start_time, 2)}')
"""文本优化"""
text_aided_config = llm_aided_config.get('text_aided', None)
if text_aided_config is not None:
if text_aided_config.get('enable', False):
llm_aided_text_start_time = time.time()
llm_aided_text(pdf_info_dict, text_aided_config)
logger.info(f'llm aided text time: {round(time.time() - llm_aided_text_start_time, 2)}')
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(pdf_info_dict, title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
"""dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict)
......
......@@ -83,26 +83,47 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
if block["type"] == "title":
origin_title_list.append(block)
title_text = merge_para_with_text(block)
title_dict[f"{i}"] = title_text
page_line_height_list = []
for line in block['lines']:
bbox = line['bbox']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1]
i += 1
# logger.info(f"Title list: {title_dict}")
title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
1. 保留原始内容:
1. 字典中每个value均为一个list,包含以下元素:
- 标题文本
- 文本行高是标题所在块的平均行高
- 标题所在的页码
2. 保留原始内容:
- 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
- 请务必保证输出的字典中元素的数量和输入的数量一致
2. 保持字典内key-value的对应关系不变
3. 保持字典内key-value的对应关系不变
3. 优化层次结构:
4. 优化层次结构:
- 为每个标题元素添加适当的层次结构
- 标题层级应具有连续性,不能跳过某一层级
- 行高较大的标题一般是更高级别的标题
- 标题从前至后的层级必须是连续的,不能跳过层级
- 标题层级最多为4级,不要添加过多的层级
- 优化后的标题为一个整数,代表该标题的层级
- 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
5. 合理性检查与微调:
- 在完成初步分级后,仔细检查分级结果的合理性
- 根据上下文关系和逻辑顺序,对不合理的分级进行微调
- 确保最终的分级结果符合文档的实际结构和逻辑
IMPORTANT:
请直接返回优化过的由标题层级组成的json,返回的json不需要格式化。
请直接返回优化过的由标题层级组成的json,格式如下:
{{"0":1,"1":2,"2":2,"3":3}}
返回的json不需要格式化。
Input title list:
{title_dict}
......@@ -110,24 +131,36 @@ Input title list:
Corrected title list:
"""
completion = client.chat.completions.create(
model=title_aided_config["model"],
messages=[
{'role': 'user', 'content': title_optimize_prompt}],
temperature=0.7,
)
json_completion = json.loads(completion.choices[0].message.content)
# logger.info(f"Title completion: {json_completion}")
retry_count = 0
max_retries = 3
json_completion = None
# logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
if len(json_completion) == len(title_dict):
while retry_count < max_retries:
try:
for i, origin_title_block in enumerate(origin_title_list):
origin_title_block["level"] = int(json_completion[str(i)])
completion = client.chat.completions.create(
model=title_aided_config["model"],
messages=[
{'role': 'user', 'content': title_optimize_prompt}],
temperature=0.7,
)
json_completion = json.loads(completion.choices[0].message.content)
# logger.info(f"Title completion: {json_completion}")
# logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
if len(json_completion) == len(title_dict):
for i, origin_title_block in enumerate(origin_title_list):
origin_title_block["level"] = int(json_completion[str(i)])
break
else:
logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
retry_count += 1
except Exception as e:
logger.exception(e)
else:
logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.")
if isinstance(e, json.decoder.JSONDecodeError):
logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
else:
logger.exception(e)
retry_count += 1
if json_completion is None:
logger.error("Failed to decode JSON after maximum retries.")
......@@ -36,7 +36,7 @@ def remove_overlaps_low_confidence_spans(spans):
def check_chars_is_overlap_in_span(chars):
for i in range(len(chars)):
for j in range(i + 1, len(chars)):
if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.9:
if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.35:
return True
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment