Unverified Commit 845a3ff0 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #969 from opendatalab/release-0.9.3

Release 0.9.3
parents d0558abb 6083e109
......@@ -48,3 +48,6 @@ debug_utils/
# sphinx docs
_build/
output/
\ No newline at end of file
......@@ -42,6 +42,7 @@
</div>
# Changelog
- 2024/11/15 0.9.3 released. Integrated [RapidTable](https://github.com/RapidAI/RapidTable) for table recognition, improving single-table parsing speed by more than 10 times, with higher accuracy and lower GPU memory usage.
- 2024/11/06 0.9.2 released. Integrated the [StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B) model for table recognition functionality.
- 2024/10/31 0.9.0 released. This is a major new version with extensive code refactoring, addressing numerous issues, improving performance, reducing hardware requirements, and enhancing usability:
- Refactored the sorting module code to use [layoutreader](https://github.com/ppaanngggg/layoutreader) for reading order sorting, ensuring high accuracy in various layouts.
......@@ -246,7 +247,7 @@ You can modify certain configurations in this file to enable or disable features
"enable": true // The formula recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
},
"table-config": {
"model": "tablemaster", // When using structEqTable, please change to "struct_eqtable".
"model": "rapid_table", // When using structEqTable, please change to "struct_eqtable".
"enable": false, // The table recognition feature is disabled by default. If you need to enable it, please change the value here to "true".
"max_time": 400
}
......@@ -261,7 +262,7 @@ If your device supports CUDA and meets the GPU requirements of the mainline envi
- [Windows 10/11 + GPU](docs/README_Windows_CUDA_Acceleration_en_US.md)
- Quick Deployment with Docker
> [!IMPORTANT]
> Docker requires a GPU with at least 16GB of VRAM, and all acceleration features are enabled by default.
> Docker requires a GPU with at least 8GB of VRAM, and all acceleration features are enabled by default.
>
> Before running this Docker, you can use the following command to check if your device supports CUDA acceleration on Docker.
>
......@@ -421,7 +422,9 @@ This project currently uses PyMuPDF to achieve advanced functionality. However,
# Acknowledgments
- [PDF-Extract-Kit](https://github.com/opendatalab/PDF-Extract-Kit)
- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [RapidTable](https://github.com/RapidAI/RapidTable)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
......
> [!Warning]
> このドキュメントはすでに古くなっています。最新版のドキュメントを参照してください:[ENGLISH](README.md)。
<div id="top">
<p align="center">
......@@ -18,9 +20,7 @@
<a href="https://trendshift.io/repositories/11174" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11174" alt="opendatalab%2FMinerU | Trendshift" style="width: 200px; height: 55px;"/></a>
<div align="center" style="color: red; background-color: #ffdddd; padding: 10px; border: 1px solid red; border-radius: 5px;">
<strong>NOTE:</strong> このドキュメントはすでに古くなっています。最新版のドキュメントを参照してください。
</div>
[English](README.md) | [简体中文](README_zh-CN.md) | [日本語](README_ja-JP.md)
......
......@@ -42,7 +42,7 @@
</div>
# 更新记录
- 2024/11/15 0.9.3发布,为表格识别功能接入了[RapidTable](https://github.com/RapidAI/RapidTable),单表解析速度提升10倍以上,准确率更高,显存占用更低
- 2024/11/06 0.9.2发布,为表格识别功能接入了[StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B)模型
- 2024/10/31 0.9.0发布,这是我们进行了大量代码重构的全新版本,解决了众多问题,提升了性能,降低了硬件需求,并提供了更丰富的易用性:
- 重构排序模块代码,使用 [layoutreader](https://github.com/ppaanngggg/layoutreader) 进行阅读顺序排序,确保在各种排版下都能实现极高准确率
......@@ -188,13 +188,13 @@ https://github.com/user-attachments/assets/4bea02c9-6d54-4cd6-97ed-dff14340982c
<td rowspan="2">GPU硬件支持列表</td>
<td colspan="2">最低要求 8G+显存</td>
<td colspan="2">3060ti/3070/4060<br>
8G显存可开启layout、公式识别和ocr加速</td>
8G显存可开启全部加速功能(表格仅限rapid_table)</td>
<td rowspan="2">None</td>
</tr>
<tr>
<td colspan="2">推荐配置 10G+显存</td>
<td colspan="2">3080/3080ti/3090/3090ti/4070/4070ti/4070tisuper/4080/4090<br>
10G显存及以上可以同时开启layout、公式识别和ocr加速和表格识别加速<br>
10G显存及以上可开启全部加速功能<br>
</td>
</tr>
</table>
......@@ -251,7 +251,7 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
"enable": true // 公式识别功能默认是开启的,如果需要关闭请修改此处的值为"false"
},
"table-config": {
"model": "tablemaster", // 使用structEqTable请修改为"struct_eqtable"
"model": "rapid_table", // 使用structEqTable请修改为"struct_eqtable"
"enable": false, // 表格识别功能默认是关闭的,如果需要开启请修改此处的值为"true"
"max_time": 400
}
......@@ -266,7 +266,7 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
- [Windows10/11 + GPU](docs/README_Windows_CUDA_Acceleration_zh_CN.md)
- 使用Docker快速部署
> [!IMPORTANT]
> Docker 需设备gpu显存大于等于16GB,默认开启所有加速功能
> Docker 需设备gpu显存大于等于8GB,默认开启所有加速功能
>
> 运行本docker前可以通过以下命令检测自己的设备是否支持在docker上使用CUDA加速
>
......@@ -431,6 +431,7 @@ TODO
- [PDF-Extract-Kit](https://github.com/opendatalab/PDF-Extract-Kit)
- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
- [StructEqTable](https://github.com/UniModal4Reasoning/StructEqTable-Deploy)
- [RapidTable](https://github.com/RapidAI/RapidTable)
- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)
- [PyMuPDF](https://github.com/pymupdf/PyMuPDF)
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
......
......@@ -19,9 +19,10 @@ def json_md_dump(
pdf_name,
content_list,
md_content,
orig_model_list,
):
# 写入模型结果到 model.json
orig_model_list = copy.deepcopy(pipe.model_list)
md_writer.write(
content=json.dumps(orig_model_list, ensure_ascii=False, indent=4),
path=f"{pdf_name}_model.json"
......@@ -87,9 +88,12 @@ def pdf_parse_main(
pdf_bytes = open(pdf_path, "rb").read() # 读取 pdf 文件的二进制数据
orig_model_list = []
if model_json_path:
# 读取已经被模型解析后的pdf文件的 json 原始数据,list 类型
model_json = json.loads(open(model_json_path, "r", encoding="utf-8").read())
orig_model_list = copy.deepcopy(model_json)
else:
model_json = []
......@@ -115,8 +119,9 @@ def pdf_parse_main(
pipe.pipe_classify()
# 如果没有传入模型数据,则使用内置模型解析
if not model_json:
if len(model_json) == 0:
pipe.pipe_analyze() # 解析
orig_model_list = copy.deepcopy(pipe.model_list)
# 执行解析
pipe.pipe_parse()
......@@ -126,7 +131,7 @@ def pdf_parse_main(
md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode="none")
if is_json_md_dump:
json_md_dump(pipe, md_writer, pdf_name, content_list, md_content)
json_md_dump(pipe, md_writer, pdf_name, content_list, md_content, orig_model_list)
if is_draw_visualization_bbox:
draw_visualization_bbox(pipe.pdf_mid_data['pdf_info'], pdf_bytes, output_path, pdf_name)
......
......@@ -15,7 +15,7 @@
"enable": true
},
"table-config": {
"model": "tablemaster",
"model": "rapid_table",
"enable": false,
"max_time": 400
},
......
......@@ -168,7 +168,7 @@ def merge_para_with_text(para_block):
# 如果是前一行带有-连字符,那么末尾不应该加空格
if __is_hyphen_at_line_end(content):
para_text += content[:-1]
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i']:
elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
para_text += content
else: # 西方文本语境下 content间需要空格分隔
para_text += f"{content} "
......
......@@ -50,4 +50,6 @@ class MODEL_NAME:
YOLO_V8_MFD = "yolo_v8_mfd"
UniMerNet_v2_Small = "unimernet_small"
\ No newline at end of file
UniMerNet_v2_Small = "unimernet_small"
RAPID_TABLE = "rapid_table"
\ No newline at end of file
......@@ -92,7 +92,7 @@ def get_table_recog_config():
table_config = config.get('table-config')
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
else:
return table_config
......
......@@ -369,10 +369,16 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
if block['type'] in [BlockType.Image, BlockType.Table]:
for sub_block in block['blocks']:
if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
for line in sub_block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:
for line in sub_block['virtual_lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
else:
for line in sub_block['lines']:
bbox = line['bbox']
index = line['index']
page_line_list.append({'index': index, 'bbox': bbox})
elif sub_block['type'] in [BlockType.ImageCaption, BlockType.TableCaption, BlockType.ImageFootnote, BlockType.TableFootnote]:
for line in sub_block['lines']:
bbox = line['bbox']
......
import numpy as np
import torch
from loguru import logger
import os
import time
from pathlib import Path
import shutil
from magic_pdf.libs.Constants import *
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.model_list import AtomicModel
import cv2
import yaml
from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try:
import cv2
import yaml
import argparse
import numpy as np
import torch
import torchtext
if torchtext.__version__ >= "0.18.0":
torchtext.disable_torchtext_deprecation_warning()
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from ultralytics import YOLO
from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor
from doclayout_yolo import YOLOv10
except ImportError as e:
logger.exception(e)
logger.error(
'Required dependency not installed, please install by \n'
'"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
exit(1)
from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
config = {
"model_dir": model_path,
"device": _device_
}
table_model = ppTableModel(config)
else:
logger.error("table model type not allow")
exit(1)
return table_model
def mfd_model_init(weight):
mfd_model = YOLO(weight)
return mfd_model
def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
args = argparse.Namespace(cfg_path=cfg_path, options=None)
cfg = Config(args)
cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
cfg.config.model.model_config.model_name = weight_dir
cfg.config.model.tokenizer_config.path = weight_dir
task = tasks.setup_task(cfg)
model = task.build_model(cfg)
model.to(_device_)
model.eval()
vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
mfr_transform = transforms.Compose([vis_processor, ])
return [model, mfr_transform]
def layout_model_init(weight, config_file, device):
model = Layoutlmv3_Predictor(weight, config_file, device)
return model
def doclayout_yolo_model_init(weight):
model = YOLOv10(weight)
return model
def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
if lang is not None:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
else:
model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
return model
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
# if not pil image, then convert to pil image
if isinstance(self.image_paths[idx], str):
raw_image = Image.open(self.image_paths[idx])
else:
raw_image = self.image_paths[idx]
if self.transform:
image = self.transform(raw_image)
return image
class AtomModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get("lang", None)
layout_model_name = kwargs.get("layout_model_name", None)
key = (atom_model_name, layout_model_name, lang)
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
if model_name == AtomicModel.Layout:
if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
atom_model = layout_model_init(
kwargs.get("layout_weights"),
kwargs.get("layout_config_file"),
kwargs.get("device")
)
elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
atom_model = doclayout_yolo_model_init(
kwargs.get("doclayout_yolo_weights"),
)
elif model_name == AtomicModel.MFD:
atom_model = mfd_model_init(
kwargs.get("mfd_weights")
)
elif model_name == AtomicModel.MFR:
atom_model = mfr_model_init(
kwargs.get("mfr_weight_dir"),
kwargs.get("mfr_cfg_path"),
kwargs.get("device")
)
elif model_name == AtomicModel.OCR:
atom_model = ocr_model_init(
kwargs.get("ocr_show_log"),
kwargs.get("det_db_box_thresh"),
kwargs.get("lang")
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
kwargs.get("table_model_name"),
kwargs.get("table_model_path"),
kwargs.get("table_max_time"),
kwargs.get("device")
)
else:
logger.error("model name not allow")
exit(1)
return atom_model
except ImportError:
pass
# Unified crop img logic
def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
# Create a white background with an additional width and height of 50
crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
# Crop image
crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
cropped_img = input_pil_img.crop(crop_box)
return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
return return_image, return_list
from magic_pdf.libs.Constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
class CustomPEKModel:
......@@ -226,7 +59,7 @@ class CustomPEKModel:
self.table_config = kwargs.get("table_config")
self.apply_table = self.table_config.get("enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
# ocr config
self.apply_ocr = ocr
......@@ -235,7 +68,8 @@ class CustomPEKModel:
logger.info(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}".format(
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
self.lang
)
)
# 初始化解析方案
......@@ -248,17 +82,17 @@ class CustomPEKModel:
# 初始化公式识别
if self.apply_formula:
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
device=self.device
)
# 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
......@@ -278,7 +112,8 @@ class CustomPEKModel:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
device=self.device
)
# 初始化ocr
if self.apply_ocr:
......@@ -305,26 +140,15 @@ class CustomPEKModel:
page_start = time.time()
latex_filling_list = []
mf_image_list = []
# layout检测
layout_start = time.time()
layout_res = []
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
layout_res = []
doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 3),
}
layout_res.append(new_item)
layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection time: {layout_cost}")
......@@ -333,59 +157,21 @@ class CustomPEKModel:
if self.apply_formula:
# 公式检测
mfd_start = time.time()
mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
mfd_res = self.mfd_model.predict(image)
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': 13 + int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 2),
'latex': '',
}
layout_res.append(new_item)
latex_filling_list.append(new_item)
bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
mf_image_list.append(bbox_img)
# 公式识别
mfr_start = time.time()
dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
mfr_res = []
for mf_img in dataloader:
mf_img = mf_img.to(self.device)
with torch.no_grad():
output = self.mfr_model.generate({'image': mf_img})
mfr_res.extend(output['pred_str'])
for res, latex in zip(latex_filling_list, mfr_res):
res['latex'] = latex_rm_whitespace(latex)
formula_list = self.mfr_model.predict(mfd_res, image)
layout_res.extend(formula_list)
mfr_cost = round(time.time() - mfr_start, 2)
logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
# Select regions for OCR / formula regions / table regions
ocr_res_list = []
table_res_list = []
single_page_mfdetrec_res = []
for res in layout_res:
if int(res['category_id']) in [13, 14]:
single_page_mfdetrec_res.append({
"bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])],
})
elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
ocr_res_list.append(res)
elif int(res['category_id']) in [5]:
table_res_list.append(res)
if torch.cuda.is_available() and self.device != 'cpu':
properties = torch.cuda.get_device_properties(self.device)
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 10:
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
logger.info(f"gc time: {gc_time}")
logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
# 清理显存
clean_vram(self.device, vram_threshold=8)
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
# ocr识别
if self.apply_ocr:
......@@ -393,23 +179,7 @@ class CustomPEKModel:
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
# Adjust the coordinates of the formula area
adjusted_mfdetrec_res = []
for mf_res in single_page_mfdetrec_res:
mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
# Adjust the coordinates of the formula area to the coordinates relative to the cropping area
x0 = mf_xmin - xmin + paste_x
y0 = mf_ymin - ymin + paste_y
x1 = mf_xmax - xmin + paste_x
y1 = mf_ymax - ymin + paste_y
# Filter formula blocks outside the graph
if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
continue
else:
adjusted_mfdetrec_res.append({
"bbox": [x0, y0, x1, y1],
})
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
......@@ -417,22 +187,8 @@ class CustomPEKModel:
# Integration results
if ocr_res:
for box_ocr_res in ocr_res:
p1, p2, p3, p4 = box_ocr_res[0]
text, score = box_ocr_res[1]
# Convert the coordinates back to the original coordinate system
p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
layout_res.append({
'category_id': 15,
'poly': p1 + p2 + p3 + p4,
'score': round(score, 2),
'text': text,
})
ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
layout_res.extend(ocr_result_list)
ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr time: {ocr_cost}")
......@@ -443,41 +199,30 @@ class CustomPEKModel:
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
# logger.info("------------------table recognition processing begins-----------------")
latex_code = None
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
table_result = self.table_model.predict(new_image, "html")
if len(table_result) > 0:
html_code = table_result[0]
else:
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
run_time = time.time() - single_table_start_time
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
# 判断是否返回正常
if latex_code:
expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
if expected_ending:
res["latex"] = latex_code
else:
logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
elif html_code:
if html_code:
expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
if expected_ending:
res["html"] = html_code
else:
logger.warning(f"table recognition processing fails, not found expected HTML table end")
else:
logger.warning(f"table recognition processing fails, not get latex or html return")
logger.warning(f"table recognition processing fails, not get html return")
logger.info(f"table time: {round(time.time() - table_start, 2)}")
logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
return layout_res
import re
def layout_rm_equation(layout_res):
rm_idxs = []
for idx, ele in enumerate(layout_res['layout_dets']):
if ele['category_id'] == 10:
rm_idxs.append(idx)
for idx in rm_idxs[::-1]:
del layout_res['layout_dets'][idx]
return layout_res
def get_croped_image(image_pil, bbox):
x_min, y_min, x_max, y_max = bbox
croped_img = image_pil.crop((x_min, y_min, x_max, y_max))
return croped_img
def latex_rm_whitespace(s: str):
"""Remove unnecessary whitespace from LaTeX code.
"""
text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
letter = '[a-zA-Z]'
noletter = '[\W_^\d]'
names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
news = s
while True:
s = news
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
if news == s:
break
return s
\ No newline at end of file
from doclayout_yolo import YOLOv10
class DocLayoutYOLOModel(object):
def __init__(self, weight, device):
self.model = YOLOv10(weight)
self.device = device
def predict(self, image):
layout_res = []
doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
doclayout_yolo_res.boxes.cls.cpu()):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
'category_id': int(cla.item()),
'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
'score': round(float(conf.item()), 3),
}
layout_res.append(new_item)
return layout_res
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment