"vscode:/vscode.git/clone" did not exist on "70eb449c08ff457df990d9d293964ecbb1854e5d"
Unverified Commit 19f72c23 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1614 from opendatalab/release-1.1.0

Release 1.1.0
parents 4d70b16b adcace44
This diff is collapsed.
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
[![Downloads](https://static.pepy.tech/badge/magic-pdf)](https://pepy.tech/project/magic-pdf) [![Downloads](https://static.pepy.tech/badge/magic-pdf)](https://pepy.tech/project/magic-pdf)
[![Downloads](https://static.pepy.tech/badge/magic-pdf/month)](https://pepy.tech/project/magic-pdf) [![Downloads](https://static.pepy.tech/badge/magic-pdf/month)](https://pepy.tech/project/magic-pdf)
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github) [![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
[![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU) [![ModelScope](https://img.shields.io/badge/Demo_on_ModelScope-purple?logo=&labelColor=white)](https://www.modelscope.cn/studios/OpenDataLab/MinerU)
[![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU) [![HuggingFace](https://img.shields.io/badge/Demo_on_HuggingFace-yellow.svg?logo=&labelColor=white)](https://huggingface.co/spaces/opendatalab/MinerU)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb)
...@@ -33,19 +33,27 @@ ...@@ -33,19 +33,27 @@
<a href="https://github.com/opendatalab/PDF-Extract-Kit">PDF-Extract-Kit: 高质量PDF解析工具箱</a>🔥🔥🔥 <a href="https://github.com/opendatalab/PDF-Extract-Kit">PDF-Extract-Kit: 高质量PDF解析工具箱</a>🔥🔥🔥
<br> <br>
<br> <br>
<a href="https://mineru.org.cn/client?source=github">更便捷的使用方式:MinerU桌面端。无需编程,无需登录,图形界面,简单交互,畅用无忧。</a>🚀🚀🚀 <a href="https://mineru.net/client?source=github">更便捷的使用方式:MinerU桌面端。无需编程,无需登录,图形界面,简单交互,畅用无忧。</a>🚀🚀🚀
</p> </p>
<!-- join us --> <!-- join us -->
<p align="center"> <p align="center">
👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="https://cdn.vansin.top/internlm/mineru.jpg" target="_blank">WeChat</a> 👋 join us on <a href="https://discord.gg/Tdedn9GTXq" target="_blank">Discord</a> and <a href="http://mineru.space/s/V85Yl" target="_blank">WeChat</a>
</p> </p>
</div> </div>
# 更新记录 # 更新记录
- 2025/01/22 1.1.0 发布,在这个版本我们重点提升了解析的精度与效率:
- 模型能力升级(需重新执行[模型下载流程](docs/how_to_download_models_zh_cn.md)以获得模型文件的增量更新)
- 布局识别模型升级到最新的`doclayout_yolo(2501)`模型,提升了layout识别精度
- 公式解析模型升级到最新的`unimernet(2501)`模型,提升了公式识别精度
- 性能优化
- 在配置满足一定条件(显存16GB+)的设备上,通过优化资源占用和重构处理流水线,整体解析速度提升50%以上
- 解析效果优化
- 在线demo([mineru.net](https://mineru.net/OpenSourceTools/Extractor)/[huggingface](https://huggingface.co/spaces/opendatalab/MinerU)/[modelscope](https://www.modelscope.cn/studios/OpenDataLab/MinerU))上新增标题分级功能(测试版本,默认开启),支持对标题进行分级,提升文档结构化程度
- 2025/01/10 1.0.1 发布,这是我们的第一个正式版本,在这个版本中,我们通过大量重构带来了全新的API接口和更广泛的兼容性,以及全新的自动语言识别功能: - 2025/01/10 1.0.1 发布,这是我们的第一个正式版本,在这个版本中,我们通过大量重构带来了全新的API接口和更广泛的兼容性,以及全新的自动语言识别功能:
- 全新API接口 - 全新API接口
- 对于数据侧API,我们引入了Dataset类,旨在提供一个强大而灵活的数据处理框架。该框架当前支持包括图像(.jpg及.png)、PDF、Word(.doc及.docx)、以及PowerPoint(.ppt及.pptx)在内的多种文档格式,确保了从简单到复杂的数据处理任务都能得到有效的支持。 - 对于数据侧API,我们引入了Dataset类,旨在提供一个强大而灵活的数据处理框架。该框架当前支持包括图像(.jpg及.png)、PDF、Word(.doc及.docx)、以及PowerPoint(.ppt及.pptx)在内的多种文档格式,确保了从简单到复杂的数据处理任务都能得到有效的支持。
...@@ -227,7 +235,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c ...@@ -227,7 +235,7 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
### 在线体验 ### 在线体验
稳定版(经过QA验证的稳定版本): 稳定版(经过QA验证的稳定版本):
[![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.org.cn/OpenSourceTools/Extractor?source=github) [![OpenDataLab](https://img.shields.io/badge/Demo_on_OpenDataLab-blue?logo=&labelColor=white)](https://mineru.net/OpenSourceTools/Extractor?source=github)
测试版(同步dev分支更新,测试新特性): 测试版(同步dev分支更新,测试新特性):
...@@ -277,6 +285,7 @@ pip install -U "magic-pdf[full]" --extra-index-url https://wheels.myhloli.com -i ...@@ -277,6 +285,7 @@ pip install -U "magic-pdf[full]" --extra-index-url https://wheels.myhloli.com -i
}, },
"table-config": { "table-config": {
"model": "rapid_table", // 默认使用"rapid_table",可以切换为"tablemaster"和"struct_eqtable" "model": "rapid_table", // 默认使用"rapid_table",可以切换为"tablemaster"和"struct_eqtable"
"sub_model": "slanet_plus", // 当model为"rapid_table"时,可以自选sub_model,可选项为"slanet_plus"和"unitable"
"enable": true, // 表格识别功能默认是开启的,如果需要关闭请修改此处的值为"false" "enable": true, // 表格识别功能默认是开启的,如果需要关闭请修改此处的值为"false"
"max_time": 400 "max_time": 400
} }
...@@ -359,6 +368,7 @@ TODO ...@@ -359,6 +368,7 @@ TODO
- [x] 基于模型的阅读顺序 - [x] 基于模型的阅读顺序
- [x] 正文中目录、列表识别 - [x] 正文中目录、列表识别
- [x] 表格识别 - [x] 表格识别
- [x] 标题分级
- [ ] 正文中代码块识别 - [ ] 正文中代码块识别
- [ ] [化学式识别](docs/chemical_knowledge_introduction/introduction.pdf) - [ ] [化学式识别](docs/chemical_knowledge_introduction/introduction.pdf)
- [ ] 几何图形识别 - [ ] 几何图形识别
...@@ -368,7 +378,6 @@ TODO ...@@ -368,7 +378,6 @@ TODO
- 阅读顺序基于模型对可阅读内容在空间中的分布进行排序,在极端复杂的排版下可能会部分区域乱序 - 阅读顺序基于模型对可阅读内容在空间中的分布进行排序,在极端复杂的排版下可能会部分区域乱序
- 不支持竖排文字 - 不支持竖排文字
- 目录和列表通过规则进行识别,少部分不常见的列表形式可能无法识别 - 目录和列表通过规则进行识别,少部分不常见的列表形式可能无法识别
- 标题只有一级,目前不支持标题分级
- 代码块在layout模型里还没有支持 - 代码块在layout模型里还没有支持
- 漫画书、艺术图册、小学教材、习题尚不能很好解析 - 漫画书、艺术图册、小学教材、习题尚不能很好解析
- 表格识别在复杂表格上可能会出现行/列识别错误 - 表格识别在复杂表格上可能会出现行/列识别错误
......
boto3>=1.28.43 boto3>=1.28.43
Brotli>=1.1.0 Brotli>=1.1.0
click>=8.1.7 click>=8.1.7
PyMuPDF>=1.24.9 PyMuPDF>=1.24.9,<=1.24.14
loguru>=0.6.0 loguru>=0.6.0
numpy>=1.21.6,<2.0.0 numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0 fast-langdetect>=0.2.3,<0.3.0
...@@ -17,10 +17,9 @@ paddlepaddle==3.0.0b1 ...@@ -17,10 +17,9 @@ paddlepaddle==3.0.0b1
struct-eqtable==0.3.2 struct-eqtable==0.3.2
einops einops
accelerate accelerate
doclayout_yolo==0.0.2
rapidocr-paddle rapidocr-paddle
rapidocr-onnxruntime rapidocr-onnxruntime
rapid_table==0.3.0 rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2 doclayout-yolo==0.0.2b1
openai openai
detectron2 detectron2
boto3>=1.28.43 boto3>=1.28.43
Brotli>=1.1.0 Brotli>=1.1.0
click>=8.1.7 click>=8.1.7
PyMuPDF>=1.24.9 PyMuPDF>=1.24.9,<=1.24.14
loguru>=0.6.0 loguru>=0.6.0
numpy>=1.21.6,<2.0.0 numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0 fast-langdetect>=0.2.3,<0.3.0
...@@ -16,10 +16,9 @@ paddleocr==2.7.3 ...@@ -16,10 +16,9 @@ paddleocr==2.7.3
struct-eqtable==0.3.2 struct-eqtable==0.3.2
einops einops
accelerate accelerate
doclayout_yolo==0.0.2
rapidocr-paddle rapidocr-paddle
rapidocr-onnxruntime rapidocr-onnxruntime
rapid_table==0.3.0 rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2 doclayout-yolo==0.0.2b1
openai openai
detectron2 detectron2
boto3>=1.28.43 boto3>=1.28.43
Brotli>=1.1.0 Brotli>=1.1.0
click>=8.1.7 click>=8.1.7
PyMuPDF>=1.24.9 PyMuPDF>=1.24.9,<=1.24.14
loguru>=0.6.0 loguru>=0.6.0
numpy>=1.21.6,<2.0.0 numpy>=1.21.6,<2.0.0
fast-langdetect>=0.2.3,<0.3.0 fast-langdetect>=0.2.3,<0.3.0
...@@ -16,10 +16,9 @@ paddleocr==2.7.3 ...@@ -16,10 +16,9 @@ paddleocr==2.7.3
struct-eqtable==0.3.2 struct-eqtable==0.3.2
einops einops
accelerate accelerate
doclayout_yolo==0.0.2
rapidocr-paddle rapidocr-paddle
rapidocr-onnxruntime rapidocr-onnxruntime
rapid_table==0.3.0 rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2 doclayout-yolo==0.0.2b1
openai openai
detectron2 detectron2
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
}, },
"table-config": { "table-config": {
"model": "rapid_table", "model": "rapid_table",
"sub_model": "slanet_plus",
"enable": true, "enable": true,
"max_time": 400 "max_time": 400
}, },
...@@ -39,5 +40,5 @@ ...@@ -39,5 +40,5 @@
"enable": false "enable": false
} }
}, },
"config_version": "1.1.0" "config_version": "1.1.1"
} }
\ No newline at end of file
...@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2): ...@@ -185,10 +185,13 @@ def calculate_iou(bbox1, bbox2):
bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
if any([bbox1_area == 0, bbox2_area == 0]):
return 0
# Compute the intersection over union by taking the intersection area # Compute the intersection over union by taking the intersection area
# and dividing it by the sum of both areas minus the intersection area # and dividing it by the sum of both areas minus the intersection area
iou = intersection_area / float(bbox1_area + bbox2_area - iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
intersection_area)
return iou return iou
......
...@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename): ...@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
for page in pdf_info: for page in pdf_info:
page_line_list = [] page_line_list = []
for block in page['preproc_blocks']: for block in page['preproc_blocks']:
if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]: if block['type'] in [BlockType.Text]:
for line in block['lines']: for line in block['lines']:
bbox = line['bbox'] bbox = line['bbox']
index = line['index'] index = line['index']
page_line_list.append({'index': index, 'bbox': bbox}) page_line_list.append({'index': index, 'bbox': bbox})
if block['type'] in [BlockType.Image, BlockType.Table]: elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
if 'virtual_lines' in block:
if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
for line in block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']: for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]: if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None: if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
......
...@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"): ...@@ -12,12 +12,20 @@ if not os.getenv("FTLANG_CACHE"):
from fast_langdetect import detect_language from fast_langdetect import detect_language
def remove_invalid_surrogates(text):
# 移除无效的 UTF-16 代理对
return ''.join(c for c in text if not (0xD800 <= ord(c) <= 0xDFFF))
def detect_lang(text: str) -> str: def detect_lang(text: str) -> str:
if len(text) == 0: if len(text) == 0:
return "" return ""
text = text.replace("\n", "") text = text.replace("\n", "")
text = remove_invalid_surrogates(text)
# print(text)
try: try:
lang_upper = detect_language(text) lang_upper = detect_language(text)
except: except:
...@@ -37,3 +45,4 @@ if __name__ == '__main__': ...@@ -37,3 +45,4 @@ if __name__ == '__main__':
print(detect_lang("<html>This is a test</html>")) print(detect_lang("<html>This is a test</html>"))
print(detect_lang("这个是中文测试。")) print(detect_lang("这个是中文测试。"))
print(detect_lang("<html>这个是中文测试。</html>")) print(detect_lang("<html>这个是中文测试。</html>"))
print(detect_lang("〖\ud835\udc46\ud835〗这是个包含utf-16的中文测试"))
\ No newline at end of file
...@@ -7,19 +7,19 @@ from loguru import logger ...@@ -7,19 +7,19 @@ from loguru import logger
from PIL import Image from PIL import Image
from magic_pdf.config.constants import MODEL_NAME from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE # from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset # from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory # from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_device # from magic_pdf.libs.config_reader import get_device
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton # from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.pdf_extract_kit import CustomPEKModel from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list) get_adjusted_mfdetrec_res, get_ocr_result_list)
from magic_pdf.operators.models import InferenceResult # from magic_pdf.operators.models import InferenceResult
YOLO_LAYOUT_BASE_BATCH_SIZE = 4 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
MFD_BASE_BATCH_SIZE = 1 MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16 MFR_BASE_BATCH_SIZE = 16
...@@ -44,19 +44,20 @@ class BatchAnalyze: ...@@ -44,19 +44,20 @@ class BatchAnalyze:
modified_images = [] modified_images = []
for image_index, image in enumerate(images): for image_index, image in enumerate(images):
pil_img = Image.fromarray(image) pil_img = Image.fromarray(image)
width, height = pil_img.size # width, height = pil_img.size
if height > width: # if height > width:
input_res = {'poly': [0, 0, width, 0, width, height, 0, height]} # input_res = {'poly': [0, 0, width, 0, width, height, 0, height]}
new_image, useful_list = crop_img( # new_image, useful_list = crop_img(
input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0 # input_res, pil_img, crop_paste_x=width // 2, crop_paste_y=0
) # )
layout_images.append(new_image) # layout_images.append(new_image)
modified_images.append([image_index, useful_list]) # modified_images.append([image_index, useful_list])
else: # else:
layout_images.append(pil_img) layout_images.append(pil_img)
images_layout_res += self.model.layout_model.batch_predict( images_layout_res += self.model.layout_model.batch_predict(
layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE # layout_images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
) )
for image_index, useful_list in modified_images: for image_index, useful_list in modified_images:
...@@ -78,7 +79,8 @@ class BatchAnalyze: ...@@ -78,7 +79,8 @@ class BatchAnalyze:
# 公式检测 # 公式检测
mfd_start_time = time.time() mfd_start_time = time.time()
images_mfd_res = self.model.mfd_model.batch_predict( images_mfd_res = self.model.mfd_model.batch_predict(
images, self.batch_ratio * MFD_BASE_BATCH_SIZE # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
images, MFD_BASE_BATCH_SIZE
) )
logger.info( logger.info(
f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}' f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
...@@ -91,10 +93,12 @@ class BatchAnalyze: ...@@ -91,10 +93,12 @@ class BatchAnalyze:
images, images,
batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE, batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
) )
mfr_count = 0
for image_index in range(len(images)): for image_index in range(len(images)):
images_layout_res[image_index] += images_formula_list[image_index] images_layout_res[image_index] += images_formula_list[image_index]
mfr_count += len(images_formula_list[image_index])
logger.info( logger.info(
f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}' f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
) )
# 清理显存 # 清理显存
...@@ -159,7 +163,7 @@ class BatchAnalyze: ...@@ -159,7 +163,7 @@ class BatchAnalyze:
elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER: elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.model.table_model.img2html(new_image) html_code = self.model.table_model.img2html(new_image)
elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE: elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = ( html_code, table_cell_bboxes, logic_points, elapse = (
self.model.table_model.predict(new_image) self.model.table_model.predict(new_image)
) )
run_time = time.time() - single_table_start_time run_time = time.time() - single_table_start_time
...@@ -195,81 +199,81 @@ class BatchAnalyze: ...@@ -195,81 +199,81 @@ class BatchAnalyze:
return images_layout_res return images_layout_res
def doc_batch_analyze( # def doc_batch_analyze(
dataset: Dataset, # dataset: Dataset,
ocr: bool = False, # ocr: bool = False,
show_log: bool = False, # show_log: bool = False,
start_page_id=0, # start_page_id=0,
end_page_id=None, # end_page_id=None,
lang=None, # lang=None,
layout_model=None, # layout_model=None,
formula_enable=None, # formula_enable=None,
table_enable=None, # table_enable=None,
batch_ratio: int | None = None, # batch_ratio: int | None = None,
) -> InferenceResult: # ) -> InferenceResult:
"""Perform batch analysis on a document dataset. # """Perform batch analysis on a document dataset.
#
Args: # Args:
dataset (Dataset): The dataset containing document pages to be analyzed. # dataset (Dataset): The dataset containing document pages to be analyzed.
ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False. # ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
show_log (bool, optional): Flag to enable logging. Defaults to False. # show_log (bool, optional): Flag to enable logging. Defaults to False.
start_page_id (int, optional): The starting page ID for analysis. Defaults to 0. # start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page. # end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
lang (str, optional): Language for OCR. Defaults to None. # lang (str, optional): Language for OCR. Defaults to None.
layout_model (optional): Layout model to be used for analysis. Defaults to None. # layout_model (optional): Layout model to be used for analysis. Defaults to None.
formula_enable (optional): Flag to enable formula detection. Defaults to None. # formula_enable (optional): Flag to enable formula detection. Defaults to None.
table_enable (optional): Flag to enable table detection. Defaults to None. # table_enable (optional): Flag to enable table detection. Defaults to None.
batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1. # batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
#
Raises: # Raises:
CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode. # CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
#
Returns: # Returns:
InferenceResult: The result of the batch analysis containing the analyzed data and the dataset. # InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
""" # """
#
if not torch.cuda.is_available(): # if not torch.cuda.is_available():
raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode') # raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
#
lang = None if lang == '' else lang # lang = None if lang == '' else lang
# TODO: auto detect batch size # # TODO: auto detect batch size
batch_ratio = 1 if batch_ratio is None else batch_ratio # batch_ratio = 1 if batch_ratio is None else batch_ratio
end_page_id = end_page_id if end_page_id else len(dataset) # end_page_id = end_page_id if end_page_id else len(dataset)
#
model_manager = ModelSingleton() # model_manager = ModelSingleton()
custom_model: CustomPEKModel = model_manager.get_model( # custom_model: CustomPEKModel = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable # ocr, show_log, lang, layout_model, formula_enable, table_enable
) # )
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio) # batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
#
model_json = [] # model_json = []
#
# batch analyze # # batch analyze
images = [] # images = []
for index in range(len(dataset)): # for index in range(len(dataset)):
if start_page_id <= index <= end_page_id: # if start_page_id <= index <= end_page_id:
page_data = dataset.get_page(index) # page_data = dataset.get_page(index)
img_dict = page_data.get_image() # img_dict = page_data.get_image()
images.append(img_dict['img']) # images.append(img_dict['img'])
analyze_result = batch_model(images) # analyze_result = batch_model(images)
#
for index in range(len(dataset)): # for index in range(len(dataset)):
page_data = dataset.get_page(index) # page_data = dataset.get_page(index)
img_dict = page_data.get_image() # img_dict = page_data.get_image()
page_width = img_dict['width'] # page_width = img_dict['width']
page_height = img_dict['height'] # page_height = img_dict['height']
if start_page_id <= index <= end_page_id: # if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0) # result = analyze_result.pop(0)
else: # else:
result = [] # result = []
#
page_info = {'page_no': index, 'height': page_height, 'width': page_width} # page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info} # page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict) # model_json.append(page_dict)
#
# TODO: clean memory when gpu memory is not enough # # TODO: clean memory when gpu memory is not enough
clean_memory_start_time = time.time() # clean_memory_start_time = time.time()
clean_memory(get_device()) # clean_memory(get_device())
logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}') # logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
#
return InferenceResult(model_json, dataset) # return InferenceResult(model_json, dataset)
...@@ -3,8 +3,12 @@ import time ...@@ -3,8 +3,12 @@ import time
# 关闭paddle的信号处理 # 关闭paddle的信号处理
import paddle import paddle
import torch
from loguru import logger from loguru import logger
from magic_pdf.model.batch_analyze import BatchAnalyze
from magic_pdf.model.sub_modules.model_utils import get_vram
paddle.disable_signal_handler() paddle.disable_signal_handler()
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
...@@ -154,33 +158,88 @@ def doc_analyze( ...@@ -154,33 +158,88 @@ def doc_analyze(
table_enable=None, table_enable=None,
) -> InferenceResult: ) -> InferenceResult:
end_page_id = end_page_id if end_page_id else len(dataset) - 1
model_manager = ModelSingleton() model_manager = ModelSingleton()
custom_model = model_manager.get_model( custom_model = model_manager.get_model(
ocr, show_log, lang, layout_model, formula_enable, table_enable ocr, show_log, lang, layout_model, formula_enable, table_enable
) )
batch_analyze = False
device = get_device()
npu_support = False
if str(device).startswith("npu"):
import torch_npu
if torch_npu.npu.is_available():
npu_support = True
if torch.cuda.is_available() and device != 'cpu' or npu_support:
gpu_memory = int(os.getenv("VIRTUAL_VRAM_SIZE", round(get_vram(device))))
if gpu_memory is not None and gpu_memory >= 8:
if 8 <= gpu_memory < 10:
batch_ratio = 2
elif 10 <= gpu_memory <= 12:
batch_ratio = 4
elif 12 < gpu_memory <= 16:
batch_ratio = 8
elif 16 < gpu_memory <= 24:
batch_ratio = 16
else:
batch_ratio = 32
if batch_ratio >= 1:
logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
batch_analyze = True
model_json = [] model_json = []
doc_analyze_start = time.time() doc_analyze_start = time.time()
if end_page_id is None: if batch_analyze:
end_page_id = len(dataset) # batch analyze
images = []
for index in range(len(dataset)): for index in range(len(dataset)):
page_data = dataset.get_page(index) if start_page_id <= index <= end_page_id:
img_dict = page_data.get_image() page_data = dataset.get_page(index)
img = img_dict['img'] img_dict = page_data.get_image()
page_width = img_dict['width'] images.append(img_dict['img'])
page_height = img_dict['height'] analyze_result = batch_model(images)
if start_page_id <= index <= end_page_id:
page_start = time.time() for index in range(len(dataset)):
result = custom_model(img) page_data = dataset.get_page(index)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----') img_dict = page_data.get_image()
else: page_width = img_dict['width']
result = [] page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
result = analyze_result.pop(0)
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
page_info = {'page_no': index, 'height': page_height, 'width': page_width} else:
page_dict = {'layout_dets': result, 'page_info': page_info} # single analyze
model_json.append(page_dict)
for index in range(len(dataset)):
page_data = dataset.get_page(index)
img_dict = page_data.get_image()
img = img_dict['img']
page_width = img_dict['width']
page_height = img_dict['height']
if start_page_id <= index <= end_page_id:
page_start = time.time()
result = custom_model(img)
logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
else:
result = []
page_info = {'page_no': index, 'height': page_height, 'width': page_width}
page_dict = {'layout_dets': result, 'page_info': page_info}
model_json.append(page_dict)
gc_start = time.time() gc_start = time.time()
clean_memory(get_device()) clean_memory(get_device())
......
...@@ -69,6 +69,7 @@ class CustomPEKModel: ...@@ -69,6 +69,7 @@ class CustomPEKModel:
self.apply_table = self.table_config.get('enable', False) self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE) self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE) self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
self.table_sub_model_name = self.table_config.get('sub_model', None)
# ocr config # ocr config
self.apply_ocr = ocr self.apply_ocr = ocr
...@@ -144,7 +145,7 @@ class CustomPEKModel: ...@@ -144,7 +145,7 @@ class CustomPEKModel:
model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml' model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
) )
), ),
device=self.device, device='cpu' if str(self.device).startswith("mps") else self.device,
) )
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
self.layout_model = atom_model_manager.get_atom_model( self.layout_model = atom_model_manager.get_atom_model(
...@@ -174,6 +175,7 @@ class CustomPEKModel: ...@@ -174,6 +175,7 @@ class CustomPEKModel:
table_max_time=self.table_max_time, table_max_time=self.table_max_time,
device=self.device, device=self.device,
ocr_engine=self.ocr_model, ocr_engine=self.ocr_model,
table_sub_model_name=self.table_sub_model_name
) )
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
...@@ -192,24 +194,24 @@ class CustomPEKModel: ...@@ -192,24 +194,24 @@ class CustomPEKModel:
layout_res = self.layout_model(image, ignore_catids=[]) layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO: elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo # doclayout_yolo
if height > width: # if height > width:
input_res = {"poly":[0,0,width,0,width,height,0,height]} # input_res = {"poly":[0,0,width,0,width,height,0,height]}
new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0) # new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list # paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
layout_res = self.layout_model.predict(new_image) # layout_res = self.layout_model.predict(new_image)
for res in layout_res: # for res in layout_res:
p1, p2, p3, p4, p5, p6, p7, p8 = res['poly'] # p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
p1 = p1 - paste_x + xmin # p1 = p1 - paste_x + xmin
p2 = p2 - paste_y + ymin # p2 = p2 - paste_y + ymin
p3 = p3 - paste_x + xmin # p3 = p3 - paste_x + xmin
p4 = p4 - paste_y + ymin # p4 = p4 - paste_y + ymin
p5 = p5 - paste_x + xmin # p5 = p5 - paste_x + xmin
p6 = p6 - paste_y + ymin # p6 = p6 - paste_y + ymin
p7 = p7 - paste_x + xmin # p7 = p7 - paste_x + xmin
p8 = p8 - paste_y + ymin # p8 = p8 - paste_y + ymin
res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8] # res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
else: # else:
layout_res = self.layout_model.predict(image) layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2) layout_cost = round(time.time() - layout_start, 2)
logger.info(f'layout detection time: {layout_cost}') logger.info(f'layout detection time: {layout_cost}')
...@@ -228,7 +230,7 @@ class CustomPEKModel: ...@@ -228,7 +230,7 @@ class CustomPEKModel:
logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}') logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
# 清理显存 # 清理显存
clean_vram(self.device, vram_threshold=8) clean_vram(self.device, vram_threshold=6)
# 从layout_res中获取ocr区域、表格区域、公式区域 # 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = ( ocr_res_list, table_res_list, single_page_mfdetrec_res = (
...@@ -276,7 +278,7 @@ class CustomPEKModel: ...@@ -276,7 +278,7 @@ class CustomPEKModel:
elif self.table_model_name == MODEL_NAME.TABLE_MASTER: elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image) html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE: elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict( html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
new_image new_image
) )
run_time = time.time() - single_table_start_time run_time = time.time() - single_table_start_time
......
...@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object): ...@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
def predict(self, image): def predict(self, image):
layout_res = [] layout_res = []
doclayout_yolo_res = self.model.predict( doclayout_yolo_res = self.model.predict(
image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device image,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False, device=self.device
)[0] )[0]
for xyxy, conf, cla in zip( for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.xyxy.cpu(),
...@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object): ...@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
image_res.cpu() image_res.cpu()
for image_res in self.model.predict( for image_res in self.model.predict(
images[index : index + batch_size], images[index : index + batch_size],
imgsz=1024, imgsz=1280,
conf=0.25, conf=0.10,
iou=0.45, iou=0.45,
verbose=False, verbose=False,
device=self.device, device=self.device,
......
...@@ -89,7 +89,7 @@ class UnimernetModel(object): ...@@ -89,7 +89,7 @@ class UnimernetModel(object):
mf_image_list.append(bbox_img) mf_image_list.append(bbox_img)
dataset = MathDataset(mf_image_list, transform=self.mfr_transform) dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=64, num_workers=0) dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
mfr_res = [] mfr_res = []
for mf_img in dataloader: for mf_img in dataloader:
mf_img = mf_img.to(self.device) mf_img = mf_img.to(self.device)
......
...@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \ ...@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None): def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE: if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time) table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER: elif table_model_type == MODEL_NAME.TABLE_MASTER:
...@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr ...@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
} }
table_model = TableMasterPaddleModel(config) table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE: elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel(ocr_engine) table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else: else:
logger.error('table model type not allow') logger.error('table model type not allow')
exit(1) exit(1)
...@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'), kwargs.get('table_model_path'),
kwargs.get('table_max_time'), kwargs.get('table_max_time'),
kwargs.get('device'), kwargs.get('device'),
kwargs.get('ocr_engine') kwargs.get('ocr_engine'),
kwargs.get('table_sub_model_name')
) )
elif model_name == AtomicModel.LangDetect: elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect: if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
......
...@@ -7,6 +7,8 @@ import base64 ...@@ -7,6 +7,8 @@ import base64
from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
import importlib.resources
from paddleocr import PaddleOCR
from ppocr.utils.utility import check_and_read from ppocr.utils.utility import check_and_read
...@@ -327,30 +329,35 @@ class ONNXModelSingleton: ...@@ -327,30 +329,35 @@ class ONNXModelSingleton:
return self._models[key] return self._models[key]
def onnx_model_init(key): def onnx_model_init(key):
if len(key) < 4:
import importlib.resources logger.error('Invalid key length, expected at least 4 elements')
resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
onnx_model = None
additional_ocr_params = {
"use_onnx": True,
"det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
"rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
"cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
"det_db_box_thresh": key[1],
"use_dilation": key[2],
"det_db_unclip_ratio": key[3],
}
# logger.info(f"additional_ocr_params: {additional_ocr_params}")
if key[0] is not None:
additional_ocr_params["lang"] = key[0]
from paddleocr import PaddleOCR
onnx_model = PaddleOCR(**additional_ocr_params)
if onnx_model is None:
logger.error('model init failed')
exit(1) exit(1)
else:
return onnx_model try:
\ No newline at end of file with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
additional_ocr_params = {
"use_onnx": True,
"det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
"rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
"cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
"det_db_box_thresh": key[1],
"use_dilation": key[2],
"det_db_unclip_ratio": key[3],
}
if key[0] is not None:
additional_ocr_params["lang"] = key[0]
# logger.info(f"additional_ocr_params: {additional_ocr_params}")
onnx_model = PaddleOCR(**additional_ocr_params)
if onnx_model is None:
logger.error('model init failed')
exit(1)
else:
return onnx_model
except Exception as e:
logger.exception(f'Error initializing model: {e}')
exit(1)
\ No newline at end of file
...@@ -2,12 +2,27 @@ import cv2 ...@@ -2,12 +2,27 @@ import cv2
import numpy as np import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from rapid_table import RapidTable from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType
from magic_pdf.libs.config_reader import get_device
class RapidTableModel(object): class RapidTableModel(object):
def __init__(self, ocr_engine): def __init__(self, ocr_engine, table_sub_model_name):
self.table_model = RapidTable() sub_model_list = [model.value for model in ModelType]
if table_sub_model_name is None:
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True, device=get_device())
else:
input_args = RapidTableInput(model_type=table_sub_model_name)
else:
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
self.table_model = RapidTable(input_args)
# if ocr_engine is None: # if ocr_engine is None:
# self.ocr_model_name = "RapidOCR" # self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available(): # if torch.cuda.is_available():
...@@ -45,7 +60,11 @@ class RapidTableModel(object): ...@@ -45,7 +60,11 @@ class RapidTableModel(object):
ocr_result = None ocr_result = None
if ocr_result: if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result) table_results = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse html_code = table_results.pred_html
table_cell_bboxes = table_results.cell_bboxes
logic_points = table_results.logic_points
elapse = table_results.elapse
return html_code, table_cell_bboxes, logic_points, elapse
else: else:
return None, None, None return None, None, None, None
import copy import copy
import math
import os import os
import re import re
import statistics import statistics
...@@ -12,7 +13,7 @@ from loguru import logger ...@@ -12,7 +13,7 @@ from loguru import logger
from magic_pdf.config.enums import SupportedPdfParseMethod from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.config.ocr_content_type import BlockType, ContentType from magic_pdf.config.ocr_content_type import BlockType, ContentType
from magic_pdf.data.dataset import Dataset, PageableData from magic_pdf.data.dataset import Dataset, PageableData
from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
from magic_pdf.libs.convert_utils import dict_to_list from magic_pdf.libs.convert_utils import dict_to_list
...@@ -117,9 +118,10 @@ def fill_char_in_spans(spans, all_chars): ...@@ -117,9 +118,10 @@ def fill_char_in_spans(spans, all_chars):
for char in all_chars: for char in all_chars:
# 跳过非法bbox的char # 跳过非法bbox的char
x1, y1, x2, y2 = char['bbox'] # x1, y1, x2, y2 = char['bbox']
if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01: # if abs(x1 - x2) <= 0.01 or abs(y1 - y2) <= 0.01:
continue # continue
for span in spans: for span in spans:
if calculate_char_in_span(char['bbox'], span['bbox'], char['c']): if calculate_char_in_span(char['bbox'], span['bbox'], char['c']):
span['chars'].append(char) span['chars'].append(char)
...@@ -173,12 +175,35 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33): ...@@ -173,12 +175,35 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
return False return False
def remove_tilted_line(text_blocks):
for block in text_blocks:
remove_lines = []
for line in block['lines']:
cosine, sine = line['dir']
# 计算弧度值
angle_radians = math.atan2(sine, cosine)
# 将弧度值转换为角度值
angle_degrees = math.degrees(angle_radians)
if 2 < abs(angle_degrees) < 88:
remove_lines.append(line)
for line in remove_lines:
block['lines'].remove(line)
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang): def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
# cid用0xfffd表示,连字符拆开 # cid用0xfffd表示,连字符拆开
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks'] # text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
# cid用0xfffd表示,连字符不拆开 # cid用0xfffd表示,连字符不拆开
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks'] #text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_LIGATURES | fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
# 自定义flags出现较多0xfffd,可能是pymupdf可以自行处理内置字典的pdf,不再使用
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
# text_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
# 移除所有角度不为0或90的line
remove_tilted_line(text_blocks_raw)
all_pymu_chars = [] all_pymu_chars = []
for block in text_blocks_raw: for block in text_blocks_raw:
for line in block['lines']: for line in block['lines']:
...@@ -365,10 +390,11 @@ def cal_block_index(fix_blocks, sorted_bboxes): ...@@ -365,10 +390,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
block['index'] = median_value block['index'] = median_value
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填 # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if block['type'] in [BlockType.ImageBody, BlockType.TableBody]: if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
block['virtual_lines'] = copy.deepcopy(block['lines']) if 'real_lines' in block:
block['lines'] = copy.deepcopy(block['real_lines']) block['virtual_lines'] = copy.deepcopy(block['lines'])
del block['real_lines'] block['lines'] = copy.deepcopy(block['real_lines'])
del block['real_lines']
else: else:
# 使用xycut排序 # 使用xycut排序
block_bboxes = [] block_bboxes = []
...@@ -417,7 +443,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -417,7 +443,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
block_weight = x1 - x0 block_weight = x1 - x0
# 如果block高度小于n行正文,则直接返回block的bbox # 如果block高度小于n行正文,则直接返回block的bbox
if line_height * 3 < block_height: if line_height * 2 < block_height:
if ( if (
block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25 block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
): # 可能是双列结构,可以切细点 ): # 可能是双列结构,可以切细点
...@@ -425,16 +451,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -425,16 +451,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
else: else:
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细) # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if block_weight > page_w * 0.4: if block_weight > page_w * 0.4:
line_height = (y1 - y0) / 3
lines = 3 lines = 3
line_height = (y1 - y0) / lines
elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点) elif block_weight > page_w * 0.25: # (可能是三列结构,也切细点)
lines = int(block_height / line_height) + 1 lines = int(block_height / line_height) + 1
else: # 判断长宽比 else: # 判断长宽比
if block_height / block_weight > 1.2: # 细长的不分 if block_height / block_weight > 1.2: # 细长的不分
return [[x0, y0, x1, y1]] return [[x0, y0, x1, y1]]
else: # 不细长的还是分成两行 else: # 不细长的还是分成两行
line_height = (y1 - y0) / 2
lines = 2 lines = 2
line_height = (y1 - y0) / lines
# 确定从哪个y位置开始绘制线条 # 确定从哪个y位置开始绘制线条
current_y = y0 current_y = y0
...@@ -453,30 +479,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h): ...@@ -453,30 +479,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
def sort_lines_by_model(fix_blocks, page_w, page_h, line_height): def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
page_line_list = [] page_line_list = []
def add_lines_to_block(b):
line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
b['lines'] = []
for line_bbox in line_bboxes:
b['lines'].append({'bbox': line_bbox, 'spans': []})
page_line_list.extend(line_bboxes)
for block in fix_blocks: for block in fix_blocks:
if block['type'] in [ if block['type'] in [
BlockType.Text, BlockType.Title, BlockType.InterlineEquation, BlockType.Text, BlockType.Title,
BlockType.ImageCaption, BlockType.ImageFootnote, BlockType.ImageCaption, BlockType.ImageFootnote,
BlockType.TableCaption, BlockType.TableFootnote BlockType.TableCaption, BlockType.TableFootnote
]: ]:
if len(block['lines']) == 0: if len(block['lines']) == 0:
bbox = block['bbox'] add_lines_to_block(block)
lines = insert_lines_into_block(bbox, line_height, page_w, page_h) elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
for line in lines: block['real_lines'] = copy.deepcopy(block['lines'])
block['lines'].append({'bbox': line, 'spans': []}) add_lines_to_block(block)
page_line_list.extend(lines)
else: else:
for line in block['lines']: for line in block['lines']:
bbox = line['bbox'] bbox = line['bbox']
page_line_list.append(bbox) page_line_list.append(bbox)
elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]: elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
bbox = block['bbox']
block['real_lines'] = copy.deepcopy(block['lines']) block['real_lines'] = copy.deepcopy(block['lines'])
lines = insert_lines_into_block(bbox, line_height, page_w, page_h) add_lines_to_block(block)
block['lines'] = []
for line in lines:
block['lines'].append({'bbox': line, 'spans': []})
page_line_list.extend(lines)
if len(page_line_list) > 200: # layoutreader最高支持512line if len(page_line_list) > 200: # layoutreader最高支持512line
return None return None
...@@ -663,12 +691,77 @@ def parse_page_core( ...@@ -663,12 +691,77 @@ def parse_page_core(
discarded_blocks = magic_model.get_discarded(page_id) discarded_blocks = magic_model.get_discarded(page_id)
text_blocks = magic_model.get_text_blocks(page_id) text_blocks = magic_model.get_text_blocks(page_id)
title_blocks = magic_model.get_title_blocks(page_id) title_blocks = magic_model.get_title_blocks(page_id)
inline_equations, interline_equations, interline_equation_blocks = ( inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
magic_model.get_equations(page_id)
)
page_w, page_h = magic_model.get_page_size(page_id) page_w, page_h = magic_model.get_page_size(page_id)
def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
def merge_two_bbox(b1, b2):
x_min = min(b1['bbox'][0], b2['bbox'][0])
y_min = min(b1['bbox'][1], b2['bbox'][1])
x_max = max(b1['bbox'][2], b2['bbox'][2])
y_max = max(b1['bbox'][3], b2['bbox'][3])
return x_min, y_min, x_max, y_max
def merge_two_blocks(b1, b2):
# 合并两个标题块的边界框
b1['bbox'] = merge_two_bbox(b1, b2)
# 合并两个标题块的文本内容
line1 = b1['lines'][0]
line2 = b2['lines'][0]
line1['bbox'] = merge_two_bbox(line1, line2)
line1['spans'].extend(line2['spans'])
return b1, b2
# 按 y 轴重叠度聚集标题块
y_overlapping_blocks = []
title_bs = [b for b in blocks if b['type'] == BlockType.Title]
while title_bs:
block1 = title_bs.pop(0)
current_row = [block1]
to_remove = []
for block2 in title_bs:
if (
__is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9)
and len(block1['lines']) == 1
and len(block2['lines']) == 1
):
current_row.append(block2)
to_remove.append(block2)
for b in to_remove:
title_bs.remove(b)
y_overlapping_blocks.append(current_row)
# 按x轴坐标排序并合并标题块
to_remove_blocks = []
for row in y_overlapping_blocks:
if len(row) == 1:
continue
# 按x轴坐标排序
row.sort(key=lambda x: x['bbox'][0])
merged_block = row[0]
for i in range(1, len(row)):
left_block = merged_block
right_block = row[i]
left_height = left_block['bbox'][3] - left_block['bbox'][1]
right_height = right_block['bbox'][3] - right_block['bbox'][1]
if (
right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold
and left_height * 0.95 < right_height < left_height * 1.05
):
merged_block, to_remove_block = merge_two_blocks(merged_block, right_block)
to_remove_blocks.append(to_remove_block)
else:
merged_block = right_block
for b in to_remove_blocks:
blocks.remove(b)
"""将所有区块的bbox整理到一起""" """将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上 # interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks = [] interline_equation_blocks = []
...@@ -753,6 +846,9 @@ def parse_page_core( ...@@ -753,6 +846,9 @@ def parse_page_core(
"""对block进行fix操作""" """对block进行fix操作"""
fix_blocks = fix_block_spans_v2(block_with_spans) fix_blocks = fix_block_spans_v2(block_with_spans)
"""同一行被断开的titile合并"""
merge_title_blocks(fix_blocks)
"""获取所有line并计算正文line的高度""" """获取所有line并计算正文line的高度"""
line_height = get_line_height(fix_blocks) line_height = get_line_height(fix_blocks)
...@@ -861,17 +957,23 @@ def pdf_parse_union( ...@@ -861,17 +957,23 @@ def pdf_parse_union(
formula_aided_config = llm_aided_config.get('formula_aided', None) formula_aided_config = llm_aided_config.get('formula_aided', None)
if formula_aided_config is not None: if formula_aided_config is not None:
if formula_aided_config.get('enable', False): if formula_aided_config.get('enable', False):
llm_aided_formula_start_time = time.time()
llm_aided_formula(pdf_info_dict, formula_aided_config) llm_aided_formula(pdf_info_dict, formula_aided_config)
logger.info(f'llm aided formula time: {round(time.time() - llm_aided_formula_start_time, 2)}')
"""文本优化""" """文本优化"""
text_aided_config = llm_aided_config.get('text_aided', None) text_aided_config = llm_aided_config.get('text_aided', None)
if text_aided_config is not None: if text_aided_config is not None:
if text_aided_config.get('enable', False): if text_aided_config.get('enable', False):
llm_aided_text_start_time = time.time()
llm_aided_text(pdf_info_dict, text_aided_config) llm_aided_text(pdf_info_dict, text_aided_config)
logger.info(f'llm aided text time: {round(time.time() - llm_aided_text_start_time, 2)}')
"""标题优化""" """标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None) title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None: if title_aided_config is not None:
if title_aided_config.get('enable', False): if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(pdf_info_dict, title_aided_config) llm_aided_title(pdf_info_dict, title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
"""dict转list""" """dict转list"""
pdf_info_list = dict_to_list(pdf_info_dict) pdf_info_list = dict_to_list(pdf_info_dict)
......
...@@ -83,26 +83,47 @@ def llm_aided_title(pdf_info_dict, title_aided_config): ...@@ -83,26 +83,47 @@ def llm_aided_title(pdf_info_dict, title_aided_config):
if block["type"] == "title": if block["type"] == "title":
origin_title_list.append(block) origin_title_list.append(block)
title_text = merge_para_with_text(block) title_text = merge_para_with_text(block)
title_dict[f"{i}"] = title_text page_line_height_list = []
for line in block['lines']:
bbox = line['bbox']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_num[5:])+1]
i += 1 i += 1
# logger.info(f"Title list: {title_dict}") # logger.info(f"Title list: {title_dict}")
title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构: title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
1. 保留原始内容: 1. 字典中每个value均为一个list,包含以下元素:
- 标题文本
- 文本行高是标题所在块的平均行高
- 标题所在的页码
2. 保留原始内容:
- 输入的字典中所有元素都是有效的,不能删除字典中的任何元素 - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
- 请务必保证输出的字典中元素的数量和输入的数量一致 - 请务必保证输出的字典中元素的数量和输入的数量一致
2. 保持字典内key-value的对应关系不变 3. 保持字典内key-value的对应关系不变
3. 优化层次结构: 4. 优化层次结构:
- 为每个标题元素添加适当的层次结构 - 为每个标题元素添加适当的层次结构
- 标题层级应具有连续性,不能跳过某一层级 - 行高较大的标题一般是更高级别的标题
- 标题从前至后的层级必须是连续的,不能跳过层级
- 标题层级最多为4级,不要添加过多的层级 - 标题层级最多为4级,不要添加过多的层级
- 优化后的标题为一个整数,代表该标题的层级 - 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
5. 合理性检查与微调:
- 在完成初步分级后,仔细检查分级结果的合理性
- 根据上下文关系和逻辑顺序,对不合理的分级进行微调
- 确保最终的分级结果符合文档的实际结构和逻辑
IMPORTANT: IMPORTANT:
请直接返回优化过的由标题层级组成的json,返回的json不需要格式化。 请直接返回优化过的由标题层级组成的json,格式如下:
{{"0":1,"1":2,"2":2,"3":3}}
返回的json不需要格式化。
Input title list: Input title list:
{title_dict} {title_dict}
...@@ -110,24 +131,36 @@ Input title list: ...@@ -110,24 +131,36 @@ Input title list:
Corrected title list: Corrected title list:
""" """
completion = client.chat.completions.create( retry_count = 0
model=title_aided_config["model"], max_retries = 3
messages=[ json_completion = None
{'role': 'user', 'content': title_optimize_prompt}],
temperature=0.7,
)
json_completion = json.loads(completion.choices[0].message.content)
# logger.info(f"Title completion: {json_completion}")
# logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}") while retry_count < max_retries:
if len(json_completion) == len(title_dict):
try: try:
for i, origin_title_block in enumerate(origin_title_list): completion = client.chat.completions.create(
origin_title_block["level"] = int(json_completion[str(i)]) model=title_aided_config["model"],
messages=[
{'role': 'user', 'content': title_optimize_prompt}],
temperature=0.7,
)
json_completion = json.loads(completion.choices[0].message.content)
# logger.info(f"Title completion: {json_completion}")
# logger.info(f"len(json_completion): {len(json_completion)}, len(title_dict): {len(title_dict)}")
if len(json_completion) == len(title_dict):
for i, origin_title_block in enumerate(origin_title_list):
origin_title_block["level"] = int(json_completion[str(i)])
break
else:
logger.warning("The number of titles in the optimized result is not equal to the number of titles in the input.")
retry_count += 1
except Exception as e: except Exception as e:
logger.exception(e) if isinstance(e, json.decoder.JSONDecodeError):
else: logger.warning(f"JSON decode error on attempt {retry_count + 1}: {e}")
logger.error("The number of titles in the optimized result is not equal to the number of titles in the input.") else:
logger.exception(e)
retry_count += 1
if json_completion is None:
logger.error("Failed to decode JSON after maximum retries.")
...@@ -36,7 +36,7 @@ def remove_overlaps_low_confidence_spans(spans): ...@@ -36,7 +36,7 @@ def remove_overlaps_low_confidence_spans(spans):
def check_chars_is_overlap_in_span(chars): def check_chars_is_overlap_in_span(chars):
for i in range(len(chars)): for i in range(len(chars)):
for j in range(i + 1, len(chars)): for j in range(i + 1, len(chars)):
if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.9: if calculate_iou(chars[i]['bbox'], chars[j]['bbox']) > 0.35:
return True return True
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment