Unverified commit 845a3ff0, authored by Xiaomeng Zhao and committed by GitHub

Merge pull request #969 from opendatalab/release-0.9.3

Release 0.9.3
parents d0558abb 6083e109
...@@ -48,3 +48,6 @@ debug_utils/
# sphinx docs
_build/
output/
\ No newline at end of file
...@@ -42,6 +42,7 @@
</div>
# Changelog
- 2024/11/15 0.9.3 released. Integrated [RapidTable](https://github.com/RapidAI/RapidTable) for table recognition, improving single-table parsing speed by more than 10 times, with higher accuracy and lower GPU memory usage.
- 2024/11/06 0.9.2 released. Integrated the [StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B) model for table recognition functionality.
- 2024/10/31 0.9.0 released. This is a major new version with extensive code refactoring, addressing numerous issues, improving performance, reducing hardware requirements, and enhancing usability:
  - Refactored the sorting module code to use [layoutreader](https://github.com/ppaanngggg/layoutreader) for reading order sorting, ensuring high accuracy in various layouts.
...@@ -246,7 +247,7 @@ You can modify certain configurations in this file to enable or disable features
        "enable": true // The formula recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
    },
    "table-config": {
-        "model": "tablemaster", // When using structEqTable, please change to "struct_eqtable".
+        "model": "rapid_table", // When using structEqTable, please change to "struct_eqtable".
        "enable": false, // The table recognition feature is disabled by default. If you need to enable it, please change the value here to "true".
        "max_time": 400
    }
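As a quick reference, the snippet below toggles these settings programmatically. This is a minimal sketch, assuming the config lives at `~/magic-pdf.json`; adjust the path if your setup differs:

```python
import json
from pathlib import Path

# Assumed config location; adjust if your magic-pdf.json lives elsewhere.
config_path = Path.home() / "magic-pdf.json"
config = json.loads(config_path.read_text(encoding="utf-8"))

# Switch the table backend and enable table recognition.
config["table-config"]["model"] = "rapid_table"   # or "struct_eqtable" / "tablemaster"
config["table-config"]["enable"] = True
config["table-config"]["max_time"] = 400          # per-table time budget, in seconds

config_path.write_text(json.dumps(config, ensure_ascii=False, indent=4), encoding="utf-8")
```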
...@@ -261,7 +262,7 @@ If your device supports CUDA and meets the GPU requirements of the mainline environment
- [Windows 10/11 + GPU](docs/README_Windows_CUDA_Acceleration_en_US.md)
- Quick Deployment with Docker
> [!IMPORTANT]
-> Docker requires a GPU with at least 16GB of VRAM, and all acceleration features are enabled by default.
+> Docker requires a GPU with at least 8GB of VRAM, and all acceleration features are enabled by default.
>
> Before running this Docker, you can use the following command to check if your device supports CUDA acceleration on Docker.
>
...@@ -421,7 +422,9 @@ This project currently uses PyMuPDF to achieve advanced functionality. However,
# Acknowledgments
- [PDF-Extract-Kit](https://github.com/opendatalab/PDF-Extract-Kit)
- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [RapidTable](https://github.com/RapidAI/RapidTable)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
......
> [!Warning]
> This document is outdated. Please refer to the latest version: [ENGLISH](README.md).
<div id="top"> <div id="top">
<p align="center"> <p align="center">
...@@ -18,9 +20,7 @@ ...@@ -18,9 +20,7 @@
<a href="https://trendshift.io/repositories/11174" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11174" alt="opendatalab%2FMinerU | Trendshift" style="width: 200px; height: 55px;"/></a> <a href="https://trendshift.io/repositories/11174" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11174" alt="opendatalab%2FMinerU | Trendshift" style="width: 200px; height: 55px;"/></a>
<div align="center" style="color: red; background-color: #ffdddd; padding: 10px; border: 1px solid red; border-radius: 5px;">
<strong>NOTE:</strong> このドキュメントはすでに古くなっています。最新版のドキュメントを参照してください。
</div>
[English](README.md) | [简体中文](README_zh-CN.md) | [日本語](README_ja-JP.md)
......
...@@ -42,7 +42,7 @@
</div>
# Changelog
- 2024/11/15 0.9.3 released. Integrated [RapidTable](https://github.com/RapidAI/RapidTable) for table recognition, improving single-table parsing speed by more than 10 times, with higher accuracy and lower GPU memory usage.
- 2024/11/06 0.9.2 released. Integrated the [StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B) model for table recognition.
- 2024/10/31 0.9.0 released. This is a major new version with extensive code refactoring that addresses numerous issues, improves performance, reduces hardware requirements, and enhances usability:
  - Refactored the sorting module to use [layoutreader](https://github.com/ppaanngggg/layoutreader) for reading-order sorting, ensuring high accuracy across various layouts.
...@@ -188,13 +188,13 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
<td rowspan="2">GPU hardware support list</td>
<td colspan="2">Minimum requirement: 8 GB+ VRAM</td>
<td colspan="2">3060ti/3070/4060<br>
-8 GB VRAM enables layout, formula recognition, and OCR acceleration</td>
+8 GB VRAM enables all acceleration features (tables limited to rapid_table)</td>
<td rowspan="2">None</td>
</tr>
<tr>
<td colspan="2">Recommended: 10 GB+ VRAM</td>
<td colspan="2">3080/3080ti/3090/3090ti/4070/4070ti/4070tisuper/4080/4090<br>
-10 GB VRAM or more enables layout, formula recognition, OCR, and table recognition acceleration simultaneously<br>
+10 GB VRAM or more enables all acceleration features<br>
</td>
</tr>
</table>
...@@ -251,7 +251,7 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
        "enable": true // Formula recognition is enabled by default; to disable it, change this value to "false"
    },
    "table-config": {
-        "model": "tablemaster", // To use structEqTable, change this to "struct_eqtable"
+        "model": "rapid_table", // To use structEqTable, change this to "struct_eqtable"
        "enable": false, // Table recognition is disabled by default; to enable it, change this value to "true"
        "max_time": 400
    }
...@@ -266,7 +266,7 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
- [Windows10/11 + GPU](docs/README_Windows_CUDA_Acceleration_zh_CN.md)
- Quick deployment with Docker
> [!IMPORTANT]
-> Docker requires a GPU with at least 16GB of VRAM; all acceleration features are enabled by default
+> Docker requires a GPU with at least 8GB of VRAM; all acceleration features are enabled by default
>
> Before running this Docker image, you can use the following command to check whether your device supports CUDA acceleration in Docker
>
...@@ -431,6 +431,7 @@ TODO
- [PDF-Extract-Kit](https://github.com/opendatalab/PDF-Extract-Kit)
- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [RapidTable](https://github.com/RapidAI/RapidTable)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
......
...@@ -19,9 +19,10 @@ def json_md_dump(
        pdf_name,
        content_list,
        md_content,
+        orig_model_list,
):
    # Write the model results to model.json
-    orig_model_list = copy.deepcopy(pipe.model_list)
    md_writer.write(
        content=json.dumps(orig_model_list, ensure_ascii=False, indent=4),
        path=f"{pdf_name}_model.json"
...@@ -87,9 +88,12 @@ def pdf_parse_main(
    pdf_bytes = open(pdf_path, "rb").read()  # read the PDF file's binary data

+    orig_model_list = []
    if model_json_path:
        # Read the raw JSON data (a list) of a PDF that has already been parsed by the model
        model_json = json.loads(open(model_json_path, "r", encoding="utf-8").read())
+        orig_model_list = copy.deepcopy(model_json)
    else:
        model_json = []
...@@ -115,8 +119,9 @@ def pdf_parse_main(
    pipe.pipe_classify()

    # If no model data was passed in, parse with the built-in model
-    if not model_json:
+    if len(model_json) == 0:
        pipe.pipe_analyze()  # parse
+        orig_model_list = copy.deepcopy(pipe.model_list)

    # Run the parse
    pipe.pipe_parse()
...@@ -126,7 +131,7 @@ def pdf_parse_main(
    md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")

    if is_json_md_dump:
-        json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
+        json_md_dump(pipe, md_writer, pdf_name, content_list, md_content, orig_model_list)

    if is_draw_visualization_bbox:
        draw_visualization_bbox(pipe.pdf_mid_data['pdf_info'], pdf_bytes, output_path, pdf_name)
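The reason the demo now threads `orig_model_list` through explicitly is that `pipe.pipe_parse()` can mutate `pipe.model_list` in place, so the snapshot has to be taken with `copy.deepcopy` before parsing. A minimal, generic illustration of the aliasing pitfall (no MinerU imports involved):

```python
import copy

model_list = [{"page": 1, "dets": [1, 2]}]
alias = model_list                     # same object: mutations show through
snapshot = copy.deepcopy(model_list)   # independent copy: preserves the original

model_list[0]["dets"].append(3)        # simulate in-place mutation during parsing
print(alias[0]["dets"])                # [1, 2, 3]
print(snapshot[0]["dets"])             # [1, 2]
```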
......
...@@ -15,7 +15,7 @@
        "enable": true
    },
    "table-config": {
-        "model": "tablemaster",
+        "model": "rapid_table",
        "enable": false,
        "max_time": 400
    },
......
...@@ -168,7 +168,7 @@ def merge_para_with_text(para_block):
            # If the previous line ends with a hyphen, no space should be appended
            if __is_hyphen_at_line_end(content):
                para_text += content[:-1]
-            elif len(content) == 1 and content not in ['A', 'I', 'a', 'i']:
+            elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
                para_text += content
            else:  # In Western-text contexts, contents need to be separated by spaces
                para_text += f"{content} "
......
...@@ -50,4 +50,6 @@ class MODEL_NAME:
    YOLO_V8_MFD = "yolo_v8_mfd"
    UniMerNet_v2_Small = "unimernet_small"
\ No newline at end of file
    RAPID_TABLE = "rapid_table"
\ No newline at end of file
...@@ -92,7 +92,7 @@ def get_table_recog_config():
    table_config = config.get('table-config')
    if table_config is None:
        logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
-        return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
+        return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
    else:
        return table_config
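A side note on the fallback: building the default through an f-string and `json.loads` round trip works but is easy to break with quoting; an equivalent plain-dict sketch (not the project's actual code) would be:

```python
# Equivalent default without the f-string/json.loads round trip (sketch).
def default_table_config(model_name: str = "rapid_table") -> dict:
    return {"model": model_name, "enable": False, "max_time": 400}
```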
......
...@@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
            if block['type'] in [BlockType.Image, BlockType.Table]:
                for sub_block in block['blocks']:
                    if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
-                        for line in sub_block['virtual_lines']:
-                            bbox = line['bbox']
-                            index = line['index']
-                            page_line_list.append({'index': index, 'bbox': bbox})
+                        if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
+                            for line in sub_block['virtual_lines']:
+                                bbox = line['bbox']
+                                index = line['index']
+                                page_line_list.append({'index': index, 'bbox': bbox})
+                        else:
+                            for line in sub_block['lines']:
+                                bbox = line['bbox']
+                                index = line['index']
+                                page_line_list.append({'index': index, 'bbox': bbox})
                    elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
                        for line in sub_block['lines']:
                            bbox = line['bbox']
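The prefer-virtual-lines-then-fall-back pattern above could be captured in a small helper; a hedged sketch (`iter_indexed_lines` is illustrative, not a function from the repo):

```python
def iter_indexed_lines(sub_block: dict):
    """Yield sortable lines: prefer 'virtual_lines' when its first entry
    carries an 'index', otherwise fall back to the real 'lines'."""
    virtual = sub_block.get('virtual_lines') or []
    lines = virtual if virtual and virtual[0].get('index') is not None else sub_block['lines']
    for line in lines:
        yield {'index': line['index'], 'bbox': line['bbox']}
```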
......
+import numpy as np
+import torch
 from loguru import logger
 import os
 import time
-from pathlib import Path
-import shutil
-from magic_pdf.libs.Constants import *
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.model.model_list import AtomicModel
+import cv2
+import yaml
+from PIL import Image

 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # keep albumentations from checking for updates
 os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger

 try:
-    import cv2
-    import yaml
-    import argparse
-    import numpy as np
-    import torch
     import torchtext
     if torchtext.__version__ >= "0.18.0":
         torchtext.disable_torchtext_deprecation_warning()
-    from PIL import Image
-    from torchvision import transforms
-    from torch.utils.data import Dataset, DataLoader
-    from ultralytics import YOLO
-    from unimernet.common.config import Config
-    import unimernet.tasks as tasks
-    from unimernet.processors import load_processor
-    from doclayout_yolo import YOLOv10
-except ImportError as e:
-    logger.exception(e)
-    logger.error(
-        'Required dependency not installed, please install by \n'
-        '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
-    exit(1)
+except ImportError:
+    pass

-from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
-from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
-from magic_pdf.model.ppTableModel import ppTableModel
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
-    if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
-        table_model = StructTableModel(model_path, max_time=max_time)
-    elif table_model_type == MODEL_NAME.TABLE_MASTER:
-        config = {
-            "model_dir": model_path,
-            "device": _device_
-        }
-        table_model = ppTableModel(config)
-    else:
-        logger.error("table model type not allow")
-        exit(1)
-    return table_model
-
-def mfd_model_init(weight):
-    mfd_model = YOLO(weight)
-    return mfd_model
-
-def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
-    args = argparse.Namespace(cfg_path=cfg_path, options=None)
-    cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-    cfg.config.model.model_config.model_name = weight_dir
-    cfg.config.model.tokenizer_config.path = weight_dir
-    task = tasks.setup_task(cfg)
-    model = task.build_model(cfg)
-    model.to(_device_)
-    model.eval()
-    vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-    mfr_transform = transforms.Compose([vis_processor, ])
-    return [model, mfr_transform]
-
-def layout_model_init(weight, config_file, device):
-    model = Layoutlmv3_Predictor(weight, config_file, device)
-    return model
-
-def doclayout_yolo_model_init(weight):
-    model = YOLOv10(weight)
-    return model
-
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
-    if lang is not None:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    else:
-        model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
-    return model
-
-class MathDataset(Dataset):
-    def __init__(self, image_paths, transform=None):
-        self.image_paths = image_paths
-        self.transform = transform
-
-    def __len__(self):
-        return len(self.image_paths)
-
-    def __getitem__(self, idx):
-        # if not pil image, then convert to pil image
-        if isinstance(self.image_paths[idx], str):
-            raw_image = Image.open(self.image_paths[idx])
-        else:
-            raw_image = self.image_paths[idx]
-        if self.transform:
-            image = self.transform(raw_image)
-        return image
-
-class AtomModelSingleton:
-    _instance = None
-    _models = {}
-
-    def __new__(cls, *args, **kwargs):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-        return cls._instance
-
-    def get_atom_model(self, atom_model_name: str, **kwargs):
-        lang = kwargs.get("lang", None)
-        layout_model_name = kwargs.get("layout_model_name", None)
-        key = (atom_model_name, layout_model_name, lang)
-        if key not in self._models:
-            self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
-        return self._models[key]
-
-def atom_model_init(model_name: str, **kwargs):
-    if model_name == AtomicModel.Layout:
-        if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
-            atom_model = layout_model_init(
-                kwargs.get("layout_weights"),
-                kwargs.get("layout_config_file"),
-                kwargs.get("device")
-            )
-        elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
-            atom_model = doclayout_yolo_model_init(
-                kwargs.get("doclayout_yolo_weights"),
-            )
-    elif model_name == AtomicModel.MFD:
-        atom_model = mfd_model_init(
-            kwargs.get("mfd_weights")
-        )
-    elif model_name == AtomicModel.MFR:
-        atom_model = mfr_model_init(
-            kwargs.get("mfr_weight_dir"),
-            kwargs.get("mfr_cfg_path"),
-            kwargs.get("device")
-        )
-    elif model_name == AtomicModel.OCR:
-        atom_model = ocr_model_init(
-            kwargs.get("ocr_show_log"),
-            kwargs.get("det_db_box_thresh"),
-            kwargs.get("lang")
-        )
-    elif model_name == AtomicModel.Table:
-        atom_model = table_model_init(
-            kwargs.get("table_model_name"),
-            kwargs.get("table_model_path"),
-            kwargs.get("table_max_time"),
-            kwargs.get("device")
-        )
-    else:
-        logger.error("model name not allow")
-        exit(1)
-    return atom_model
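`AtomModelSingleton` and `atom_model_init` move to `magic_pdf.model.sub_modules.model_init` in this commit (see the new imports below). Assuming the API stays as shown above, usage looks roughly like this; the OCR arguments mirror `ocr_model_init` and are placeholders:

```python
atom_model_manager = AtomModelSingleton()

# The cache key is (model name, layout model name, lang), so repeated calls
# with the same key return the same instance instead of reloading weights.
ocr_model = atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.OCR,
    ocr_show_log=False,
    det_db_box_thresh=0.3,
    lang="en",
)
assert ocr_model is atom_model_manager.get_atom_model(
    atom_model_name=AtomicModel.OCR, ocr_show_log=False, det_db_box_thresh=0.3, lang="en",
)
```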
-# Unified crop img logic
-def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
-    crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
-    crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
-    # Create a white background with an additional width and height of 50
-    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
-    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
-    return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
-    # Crop image
-    crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
-    cropped_img = input_pil_img.crop(crop_box)
-    return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
-    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
-    return return_image, return_list
+from magic_pdf.libs.Constants import *
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
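For orientation, `crop_img` (now imported from `magic_pdf.model.sub_modules.model_utils`) returns the padded crop plus the offsets needed to map coordinates between the crop and the page; `get_adjusted_mfdetrec_res` and `get_ocr_result_list` below rely on exactly this contract. A small sketch with a made-up detection:

```python
from PIL import Image

# Hypothetical detection: 'poly' is [x0, y0, x1, y0, x1, y1, x0, y1].
res = {'poly': [100, 200, 400, 200, 400, 300, 100, 300]}
page = Image.new('RGB', (1000, 1000), 'white')

cropped, useful_list = crop_img(res, page, crop_paste_x=50, crop_paste_y=50)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_w, new_h = useful_list
# Mapping a point from crop coordinates back to page coordinates:
#   page_x = crop_x - paste_x + xmin
#   page_y = crop_y - paste_y + ymin
```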
class CustomPEKModel:
...@@ -226,7 +59,7 @@ class CustomPEKModel:
        self.table_config = kwargs.get("table_config")
        self.apply_table = self.table_config.get("enable", False)
        self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
-        self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+        self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
        # ocr config
        self.apply_ocr = ocr
...@@ -235,7 +68,8 @@ class CustomPEKModel:
        logger.info(
            "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
            "apply_table: {}, table_model: {}, lang: {}".format(
-                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
+                self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
+                self.lang
            )
        )
        # Initialize the parsing pipeline
...@@ -248,17 +82,17 @@ class CustomPEKModel:
        # Initialize formula recognition
        if self.apply_formula:
            # Initialize the formula detection model
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
-                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
+                mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
+                device=self.device
            )
            # Initialize the formula recognition model
            mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
            mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
-            self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
+            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
...@@ -278,7 +112,8 @@ class CustomPEKModel:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
-                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+                doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+                device=self.device
            )
        # Initialize OCR
        if self.apply_ocr:
...@@ -305,26 +140,15 @@ class CustomPEKModel:
        page_start = time.time()

-        latex_filling_list = []
-        mf_image_list = []
-
        # Layout detection
        layout_start = time.time()
        layout_res = []
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
-            layout_res = []
-            doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
-            for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 3),
-                }
-                layout_res.append(new_item)
+            layout_res = self.layout_model.predict(image)
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f"layout detection time: {layout_cost}")
...@@ -333,59 +157,21 @@ class CustomPEKModel:
        if self.apply_formula:
            # Formula detection
            mfd_start = time.time()
-            mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+            mfd_res = self.mfd_model.predict(image)
            logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
-            for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
-                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
-                new_item = {
-                    'category_id': 13 + int(cla.item()),
-                    'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                    'score': round(float(conf.item()), 2),
-                    'latex': '',
-                }
-                layout_res.append(new_item)
-                latex_filling_list.append(new_item)
-                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
-                mf_image_list.append(bbox_img)

            # Formula recognition
            mfr_start = time.time()
-            dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
-            dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
-            mfr_res = []
-            for mf_img in dataloader:
-                mf_img = mf_img.to(self.device)
-                with torch.no_grad():
-                    output = self.mfr_model.generate({'image': mf_img})
-                mfr_res.extend(output['pred_str'])
-            for res, latex in zip(latex_filling_list, mfr_res):
-                res['latex'] = latex_rm_whitespace(latex)
+            formula_list = self.mfr_model.predict(mfd_res, image)
+            layout_res.extend(formula_list)
            mfr_cost = round(time.time() - mfr_start, 2)
-            logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
+            logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")

-        # Select regions for OCR / formula regions / table regions
-        ocr_res_list = []
-        table_res_list = []
-        single_page_mfdetrec_res = []
-        for res in layout_res:
-            if int(res['category_id']) in [13, 14]:
-                single_page_mfdetrec_res.append({
-                    "bbox": [int(res['poly'][0]), int(res['poly'][1]),
-                             int(res['poly'][4]), int(res['poly'][5])],
-                })
-            elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
-                ocr_res_list.append(res)
-            elif int(res['category_id']) in [5]:
-                table_res_list.append(res)
-
-        if torch.cuda.is_available() and self.device != 'cpu':
-            properties = torch.cuda.get_device_properties(self.device)
-            total_memory = properties.total_memory / (1024 ** 3)  # convert bytes to GB
-            if total_memory <= 10:
-                gc_start = time.time()
-                clean_memory()
-                gc_time = round(time.time() - gc_start, 2)
-                logger.info(f"gc time: {gc_time}")
+        # Free GPU memory
+        clean_vram(self.device, vram_threshold=8)
+
+        # Pull the OCR, table, and formula regions out of layout_res
+        ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)

        # OCR recognition
        if self.apply_ocr:
...@@ -393,23 +179,7 @@ class CustomPEKModel:
            # Process each area that requires OCR processing
            for res in ocr_res_list:
                new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                # Adjust the coordinates of the formula area
-                adjusted_mfdetrec_res = []
-                for mf_res in single_page_mfdetrec_res:
-                    mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
-                    # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
-                    x0 = mf_xmin - xmin + paste_x
-                    y0 = mf_ymin - ymin + paste_y
-                    x1 = mf_xmax - xmin + paste_x
-                    y1 = mf_ymax - ymin + paste_y
-                    # Filter formula blocks outside the graph
-                    if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
-                        continue
-                    else:
-                        adjusted_mfdetrec_res.append({
-                            "bbox": [x0, y0, x1, y1],
-                        })
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)

                # OCR recognition
                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
...@@ -417,22 +187,8 @@ class CustomPEKModel:
                # Integration results
                if ocr_res:
-                    for box_ocr_res in ocr_res:
-                        p1, p2, p3, p4 = box_ocr_res[0]
-                        text, score = box_ocr_res[1]
-                        # Convert the coordinates back to the original coordinate system
-                        p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
-                        p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
-                        p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
-                        p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
-                        layout_res.append({
-                            'category_id': 15,
-                            'poly': p1 + p2 + p3 + p4,
-                            'score': round(score, 2),
-                            'text': text,
-                        })
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    layout_res.extend(ocr_result_list)

            ocr_cost = round(time.time() - ocr_start, 2)
            logger.info(f"ocr time: {ocr_cost}")
...@@ -443,41 +199,30 @@ class CustomPEKModel:
            for res in table_res_list:
                new_image, _ = crop_img(res, pil_img)
                single_table_start_time = time.time()
-                # logger.info("------------------table recognition processing begins-----------------")
-                latex_code = None
                html_code = None
                if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                    with torch.no_grad():
                        table_result = self.table_model.predict(new_image, "html")
                        if len(table_result) > 0:
                            html_code = table_result[0]
-                else:
+                elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                    html_code = self.table_model.img2html(new_image)
+                elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
+                    html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)

                run_time = time.time() - single_table_start_time
-                # logger.info(f"------------table recognition processing ends within {run_time}s-----")
                if run_time > self.table_max_time:
-                    logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
+                    logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")

                # Check whether the result came back normally
-                if latex_code:
-                    expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
-                    if expected_ending:
-                        res["latex"] = latex_code
-                    else:
-                        logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
-                elif html_code:
+                if html_code:
                    expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
                    if expected_ending:
                        res["html"] = html_code
                    else:
                        logger.warning(f"table recognition processing fails, not found expected HTML table end")
                else:
-                    logger.warning(f"table recognition processing fails, not get latex or html return")
+                    logger.warning(f"table recognition processing fails, not get html return")
            logger.info(f"table time: {round(time.time() - table_start, 2)}")

        logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
        return layout_res
import re
def layout_rm_equation(layout_res):
    rm_idxs = []
    for idx, ele in enumerate(layout_res['layout_dets']):
        if ele['category_id'] == 10:
            rm_idxs.append(idx)
    for idx in rm_idxs[::-1]:
        del layout_res['layout_dets'][idx]
    return layout_res

def get_croped_image(image_pil, bbox):
    x_min, y_min, x_max, y_max = bbox
    croped_img = image_pil.crop((x_min, y_min, x_max, y_max))
    return croped_img

def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
    letter = '[a-zA-Z]'
    noletter = r'[\W_^\d]'
    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
    news = s
    while True:
        s = news
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
        if news == s:
            break
    return s
\ No newline at end of file
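Assuming `latex_rm_whitespace` above is in scope, its effect looks roughly like this (outputs reasoned from the regexes above, not taken from the repo's tests):

```python
print(latex_rm_whitespace(r"\mathrm {softmax} ( x )"))  # -> '\mathrm{softmax}(x)'
print(latex_rm_whitespace(r"a + b = c"))                # -> 'a+b=c'
```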
from doclayout_yolo import YOLOv10

class DocLayoutYOLOModel(object):
    def __init__(self, weight, device):
        self.model = YOLOv10(weight)
        self.device = device

    def predict(self, image):
        layout_res = []
        doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
        for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
                                   doclayout_yolo_res.boxes.cls.cpu()):
            xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
            new_item = {
                'category_id': int(cla.item()),
                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
                'score': round(float(conf.item()), 3),
            }
            layout_res.append(new_item)
        return layout_res
\ No newline at end of file
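A minimal usage sketch for the new wrapper; the weight path and page image are placeholders, and `doclayout_yolo` must be installed:

```python
from PIL import Image

# Hypothetical checkpoint path; substitute your downloaded DocLayout-YOLO weights.
model = DocLayoutYOLOModel(weight="models/doclayout_yolo.pt", device="cuda")
page = Image.open("page_0.png")
for det in model.predict(page):
    print(det['category_id'], det['score'], det['poly'])
```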