diff --git a/Dockerfile b/Dockerfile
index 5366199c326e31f4d34993a2c86c1143ff26ccf4..7cef3b525533dab886f38d2e779a5673ea6b62ef 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -42,7 +42,7 @@ RUN /bin/bash -c "wget https://gitee.com/myhloli/MinerU/raw/master/magic-pdf.tem
# Download models and update the configuration file
RUN /bin/bash -c "pip3 install modelscope && \
- wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py && \
+ wget https://gitee.com/myhloli/MinerU/raw/master/scripts/download_models.py && \
python3 download_models.py && \
sed -i 's|/tmp/models|/root/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models|g' /root/magic-pdf.json && \
sed -i 's|cpu|cuda|g' /root/magic-pdf.json"
diff --git a/README.md b/README.md
index 5dcb386ae79b52e0234ae35751ec66f6e48d01d8..abf8e6e87a57d18b4b7dd5d45b96351d1ab73ed8 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@
[](https://colab.research.google.com/gist/myhloli/3b3a00a4a0a61577b6c30f989092d20d/mineru_demo.ipynb)
[](https://arxiv.org/abs/2409.18839)
+
@@ -41,6 +42,7 @@
# Changelog
+- 2024/11/15 0.9.3 released. Integrated [RapidTable](https://github.com/RapidAI/RapidTable) for table recognition, improving single-table parsing speed by more than 10 times, with higher accuracy and lower GPU memory usage.
- 2024/11/06 0.9.2 released. Integrated the [StructTable-InternVL2-1B](https://huggingface.co/U4R/StructTable-InternVL2-1B) model for table recognition functionality.
- 2024/10/31 0.9.0 released. This is a major new version with extensive code refactoring, addressing numerous issues, improving performance, reducing hardware requirements, and enhancing usability:
- Refactored the sorting module code to use [layoutreader](https://github.com/ppaanngggg/layoutreader) for reading order sorting, ensuring high accuracy in various layouts.
@@ -75,10 +77,12 @@
| Operating System | +|||||
| Ubuntu 22.04 LTS | +Windows 10 / 11 | +macOS 11+ | +|||
| CPU | +x86_64(unsupported ARM Linux) | +x86_64(unsupported ARM Windows) | +x86_64 / arm64 | +||
| Memory | +16GB or more, recommended 32GB+ | +||||
| Python Version | +3.10(Please make sure to create a Python 3.10 virtual environment using conda) | +||||
| Nvidia Driver Version | +latest (Proprietary Driver) | +latest | +None | +||
| CUDA Environment | +Automatic installation [12.1 (pytorch) + 11.8 (paddle)] | +11.8 (manual installation) + cuDNN v8.7.0 (manual installation) | +None | +||
| GPU Hardware Support List | +Minimum Requirement 8G+ VRAM | +3060ti/3070/4060 + 8G VRAM enables layout, formula recognition acceleration and OCR acceleration |
+ None | +||
| Recommended Configuration 10G+ VRAM | +3080/3080ti/3090/3090ti/4070/4070ti/4070tisuper/4080/4090 + 10G VRAM or more can enable layout, formula recognition, OCR acceleration and table recognition acceleration simultaneously + |
+ ||||
| 操作系统 | +|||||
| Ubuntu 22.04 LTS | +Windows 10 / 11 | +macOS 11+ | +|||
| CPU | +x86_64(暂不支持ARM Linux) | +x86_64(暂不支持ARM Windows) | +x86_64 / arm64 | +||
| 内存 | +大于等于16GB,推荐32G以上 | +||||
| python版本 | +3.10 (请务必通过conda创建3.10虚拟环境) | +||||
| Nvidia Driver 版本 | +latest(专有驱动) | +latest | +None | +||
| CUDA环境 | +自动安装[12.1(pytorch)+11.8(paddle)] | +11.8(手动安装)+cuDNN v8.7.0(手动安装) | +None | +||
| GPU硬件支持列表 | +最低要求 8G+显存 | +3060ti/3070/4060 + 8G显存可开启全部加速功能(表格仅限rapid_table) |
+ None | +||
| 推荐配置 10G+显存 | +3080/3080ti/3090/3090ti/4070/4070ti/4070tisuper/4080/4090 + 10G显存及以上可开启全部加速功能 + |
+ ||||
使用python脚本 从Hugging Face下载模型文件
pip install huggingface_hub
-wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models_hf.py -O download_models_hf.py
+wget https://gitee.com/myhloli/MinerU/raw/master/scripts/download_models_hf.py -O download_models_hf.py
python download_models_hf.py
@@ -18,7 +18,7 @@ python download_models_hf.py
```bash
pip install modelscope
-wget https://gitee.com/myhloli/MinerU/raw/master/docs/download_models.py -O download_models.py
+wget https://gitee.com/myhloli/MinerU/raw/master/scripts/download_models.py -O download_models.py
python download_models.py
```
diff --git a/magic-pdf.template.json b/magic-pdf.template.json
index 114dfce32af9b63a1bcb190fba41e1fe000eda45..cdb3dab6a0cba60ced35656fc031121491e8f318 100644
--- a/magic-pdf.template.json
+++ b/magic-pdf.template.json
@@ -15,7 +15,7 @@
"enable": true
},
"table-config": {
- "model": "tablemaster",
+ "model": "rapid_table",
"enable": false,
"max_time": 400
},
diff --git a/magic_pdf/dict2md/ocr_mkcontent.py b/magic_pdf/dict2md/ocr_mkcontent.py
index 2e2ce76ef46dbb71f1613c8dfe2d577da08224fc..63e11d72c3bcd04c50607e4dd1f3d4b45ecd6299 100644
--- a/magic_pdf/dict2md/ocr_mkcontent.py
+++ b/magic_pdf/dict2md/ocr_mkcontent.py
@@ -168,7 +168,7 @@ def merge_para_with_text(para_block):
# 如果是前一行带有-连字符,那么末尾不应该加空格
if __is_hyphen_at_line_end(content):
para_text += content[:-1]
- elif len(content) == 1 and content not in ['A', 'I', 'a', 'i']:
+ elif len(content) == 1 and content not in ['A', 'I', 'a', 'i'] and not content.isdigit():
para_text += content
else: # 西方文本语境下 content间需要空格分隔
para_text += f"{content} "
diff --git a/magic_pdf/libs/Constants.py b/magic_pdf/libs/Constants.py
index 0799f6fdeaf67425e3267b79cc563e422c3297b2..188465e80d194ea70a5aaaba4c09160b2926e85d 100644
--- a/magic_pdf/libs/Constants.py
+++ b/magic_pdf/libs/Constants.py
@@ -50,4 +50,6 @@ class MODEL_NAME:
YOLO_V8_MFD = "yolo_v8_mfd"
- UniMerNet_v2_Small = "unimernet_small"
\ No newline at end of file
+ UniMerNet_v2_Small = "unimernet_small"
+
+ RAPID_TABLE = "rapid_table"
\ No newline at end of file
diff --git a/magic_pdf/libs/config_reader.py b/magic_pdf/libs/config_reader.py
index 5e1a300d9f40e875f611b3452307091707719dea..b1126b647fced58f4bb78d35c6a45def2c8d03b1 100644
--- a/magic_pdf/libs/config_reader.py
+++ b/magic_pdf/libs/config_reader.py
@@ -92,7 +92,7 @@ def get_table_recog_config():
table_config = config.get('table-config')
if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
- return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}')
+ return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
else:
return table_config
diff --git a/magic_pdf/model/pdf_extract_kit.py b/magic_pdf/model/pdf_extract_kit.py
index 21a07f5b31d8e9fb978b0c2fa901c3cfa2d905db..0c0e664bf432bb5f6dc856e47211061366218f6f 100644
--- a/magic_pdf/model/pdf_extract_kit.py
+++ b/magic_pdf/model/pdf_extract_kit.py
@@ -1,195 +1,28 @@
+import numpy as np
+import torch
from loguru import logger
import os
import time
-from pathlib import Path
-import shutil
-from magic_pdf.libs.Constants import *
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.model.model_list import AtomicModel
+import cv2
+import yaml
+from PIL import Image
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
+
try:
- import cv2
- import yaml
- import argparse
- import numpy as np
- import torch
import torchtext
if torchtext.__version__ >= "0.18.0":
torchtext.disable_torchtext_deprecation_warning()
- from PIL import Image
- from torchvision import transforms
- from torch.utils.data import Dataset, DataLoader
- from ultralytics import YOLO
- from unimernet.common.config import Config
- import unimernet.tasks as tasks
- from unimernet.processors import load_processor
- from doclayout_yolo import YOLOv10
-
-except ImportError as e:
- logger.exception(e)
- logger.error(
- 'Required dependency not installed, please install by \n'
- '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
- exit(1)
-
-from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
-from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
-from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
-from magic_pdf.model.ppTableModel import ppTableModel
-
-
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
- if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
- table_model = StructTableModel(model_path, max_time=max_time)
- elif table_model_type == MODEL_NAME.TABLE_MASTER:
- config = {
- "model_dir": model_path,
- "device": _device_
- }
- table_model = ppTableModel(config)
- else:
- logger.error("table model type not allow")
- exit(1)
- return table_model
-
-
-def mfd_model_init(weight):
- mfd_model = YOLO(weight)
- return mfd_model
-
-
-def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
- args = argparse.Namespace(cfg_path=cfg_path, options=None)
- cfg = Config(args)
- cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
- cfg.config.model.model_config.model_name = weight_dir
- cfg.config.model.tokenizer_config.path = weight_dir
- task = tasks.setup_task(cfg)
- model = task.build_model(cfg)
- model.to(_device_)
- model.eval()
- vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
- mfr_transform = transforms.Compose([vis_processor, ])
- return [model, mfr_transform]
-
-
-def layout_model_init(weight, config_file, device):
- model = Layoutlmv3_Predictor(weight, config_file, device)
- return model
-
-
-def doclayout_yolo_model_init(weight):
- model = YOLOv10(weight)
- return model
-
-
-def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None, use_dilation=True, det_db_unclip_ratio=1.8):
- if lang is not None:
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
- else:
- model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, use_dilation=use_dilation, det_db_unclip_ratio=det_db_unclip_ratio)
- return model
-
-
-class MathDataset(Dataset):
- def __init__(self, image_paths, transform=None):
- self.image_paths = image_paths
- self.transform = transform
-
- def __len__(self):
- return len(self.image_paths)
-
- def __getitem__(self, idx):
- # if not pil image, then convert to pil image
- if isinstance(self.image_paths[idx], str):
- raw_image = Image.open(self.image_paths[idx])
- else:
- raw_image = self.image_paths[idx]
- if self.transform:
- image = self.transform(raw_image)
- return image
-
-
-class AtomModelSingleton:
- _instance = None
- _models = {}
-
- def __new__(cls, *args, **kwargs):
- if cls._instance is None:
- cls._instance = super().__new__(cls)
- return cls._instance
-
- def get_atom_model(self, atom_model_name: str, **kwargs):
- lang = kwargs.get("lang", None)
- layout_model_name = kwargs.get("layout_model_name", None)
- key = (atom_model_name, layout_model_name, lang)
- if key not in self._models:
- self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
- return self._models[key]
-
-
-def atom_model_init(model_name: str, **kwargs):
-
- if model_name == AtomicModel.Layout:
- if kwargs.get("layout_model_name") == MODEL_NAME.LAYOUTLMv3:
- atom_model = layout_model_init(
- kwargs.get("layout_weights"),
- kwargs.get("layout_config_file"),
- kwargs.get("device")
- )
- elif kwargs.get("layout_model_name") == MODEL_NAME.DocLayout_YOLO:
- atom_model = doclayout_yolo_model_init(
- kwargs.get("doclayout_yolo_weights"),
- )
- elif model_name == AtomicModel.MFD:
- atom_model = mfd_model_init(
- kwargs.get("mfd_weights")
- )
- elif model_name == AtomicModel.MFR:
- atom_model = mfr_model_init(
- kwargs.get("mfr_weight_dir"),
- kwargs.get("mfr_cfg_path"),
- kwargs.get("device")
- )
- elif model_name == AtomicModel.OCR:
- atom_model = ocr_model_init(
- kwargs.get("ocr_show_log"),
- kwargs.get("det_db_box_thresh"),
- kwargs.get("lang")
- )
- elif model_name == AtomicModel.Table:
- atom_model = table_model_init(
- kwargs.get("table_model_name"),
- kwargs.get("table_model_path"),
- kwargs.get("table_max_time"),
- kwargs.get("device")
- )
- else:
- logger.error("model name not allow")
- exit(1)
-
- return atom_model
-
+except ImportError:
+ pass
-# Unified crop img logic
-def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
- crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
- crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
- # Create a white background with an additional width and height of 50
- crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
- crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
- return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
-
- # Crop image
- crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
- cropped_img = input_pil_img.crop(crop_box)
- return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
- return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
- return return_image, return_list
+from magic_pdf.libs.Constants import *
+from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
+from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
class CustomPEKModel:
@@ -226,7 +59,7 @@ class CustomPEKModel:
self.table_config = kwargs.get("table_config")
self.apply_table = self.table_config.get("enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
- self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER)
+ self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
# ocr config
self.apply_ocr = ocr
@@ -235,7 +68,8 @@ class CustomPEKModel:
logger.info(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}".format(
- self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name, self.lang
+ self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
+ self.lang
)
)
# 初始化解析方案
@@ -248,17 +82,17 @@ class CustomPEKModel:
# 初始化公式识别
if self.apply_formula:
-
# 初始化公式检测模型
self.mfd_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFD,
- mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name]))
+ mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
+ device=self.device
)
# 初始化公式解析模型
mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
- self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
+ self.mfr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.MFR,
mfr_weight_dir=mfr_weight_dir,
mfr_cfg_path=mfr_cfg_path,
@@ -278,7 +112,8 @@ class CustomPEKModel:
self.layout_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Layout,
layout_model_name=MODEL_NAME.DocLayout_YOLO,
- doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
+ doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
+ device=self.device
)
# 初始化ocr
if self.apply_ocr:
@@ -305,26 +140,15 @@ class CustomPEKModel:
page_start = time.time()
- latex_filling_list = []
- mf_image_list = []
-
# layout检测
layout_start = time.time()
+ layout_res = []
if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
# layoutlmv3
layout_res = self.layout_model(image, ignore_catids=[])
elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
# doclayout_yolo
- layout_res = []
- doclayout_yolo_res = self.layout_model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
- for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(), doclayout_yolo_res.boxes.cls.cpu()):
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
- new_item = {
- 'category_id': int(cla.item()),
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
- 'score': round(float(conf.item()), 3),
- }
- layout_res.append(new_item)
+ layout_res = self.layout_model.predict(image)
layout_cost = round(time.time() - layout_start, 2)
logger.info(f"layout detection time: {layout_cost}")
@@ -333,59 +157,21 @@ class CustomPEKModel:
if self.apply_formula:
# 公式检测
mfd_start = time.time()
- mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+ mfd_res = self.mfd_model.predict(image)
logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")
- for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
- xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
- new_item = {
- 'category_id': 13 + int(cla.item()),
- 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
- 'score': round(float(conf.item()), 2),
- 'latex': '',
- }
- layout_res.append(new_item)
- latex_filling_list.append(new_item)
- bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
- mf_image_list.append(bbox_img)
# 公式识别
mfr_start = time.time()
- dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
- dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
- mfr_res = []
- for mf_img in dataloader:
- mf_img = mf_img.to(self.device)
- with torch.no_grad():
- output = self.mfr_model.generate({'image': mf_img})
- mfr_res.extend(output['pred_str'])
- for res, latex in zip(latex_filling_list, mfr_res):
- res['latex'] = latex_rm_whitespace(latex)
+ formula_list = self.mfr_model.predict(mfd_res, image)
+ layout_res.extend(formula_list)
mfr_cost = round(time.time() - mfr_start, 2)
- logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
-
- # Select regions for OCR / formula regions / table regions
- ocr_res_list = []
- table_res_list = []
- single_page_mfdetrec_res = []
- for res in layout_res:
- if int(res['category_id']) in [13, 14]:
- single_page_mfdetrec_res.append({
- "bbox": [int(res['poly'][0]), int(res['poly'][1]),
- int(res['poly'][4]), int(res['poly'][5])],
- })
- elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
- ocr_res_list.append(res)
- elif int(res['category_id']) in [5]:
- table_res_list.append(res)
-
- if torch.cuda.is_available() and self.device != 'cpu':
- properties = torch.cuda.get_device_properties(self.device)
- total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
- if total_memory <= 10:
- gc_start = time.time()
- clean_memory()
- gc_time = round(time.time() - gc_start, 2)
- logger.info(f"gc time: {gc_time}")
+ logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")
+
+ # 清理显存
+ clean_vram(self.device, vram_threshold=8)
+
+ # 从layout_res中获取ocr区域、表格区域、公式区域
+ ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
# ocr识别
if self.apply_ocr:
@@ -393,23 +179,7 @@ class CustomPEKModel:
# Process each area that requires OCR processing
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
- paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
- # Adjust the coordinates of the formula area
- adjusted_mfdetrec_res = []
- for mf_res in single_page_mfdetrec_res:
- mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
- # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
- x0 = mf_xmin - xmin + paste_x
- y0 = mf_ymin - ymin + paste_y
- x1 = mf_xmax - xmin + paste_x
- y1 = mf_ymax - ymin + paste_y
- # Filter formula blocks outside the graph
- if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
- continue
- else:
- adjusted_mfdetrec_res.append({
- "bbox": [x0, y0, x1, y1],
- })
+ adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
@@ -417,22 +187,8 @@ class CustomPEKModel:
# Integration results
if ocr_res:
- for box_ocr_res in ocr_res:
- p1, p2, p3, p4 = box_ocr_res[0]
- text, score = box_ocr_res[1]
-
- # Convert the coordinates back to the original coordinate system
- p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
- p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
- p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
- p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
-
- layout_res.append({
- 'category_id': 15,
- 'poly': p1 + p2 + p3 + p4,
- 'score': round(score, 2),
- 'text': text,
- })
+ ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+ layout_res.extend(ocr_result_list)
ocr_cost = round(time.time() - ocr_start, 2)
logger.info(f"ocr time: {ocr_cost}")
@@ -443,41 +199,30 @@ class CustomPEKModel:
for res in table_res_list:
new_image, _ = crop_img(res, pil_img)
single_table_start_time = time.time()
- # logger.info("------------------table recognition processing begins-----------------")
- latex_code = None
html_code = None
if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
with torch.no_grad():
table_result = self.table_model.predict(new_image, "html")
if len(table_result) > 0:
html_code = table_result[0]
- else:
+ elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
-
+ elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
+ html_code, table_cell_bboxes, elapse = self.table_model.predict(new_image)
run_time = time.time() - single_table_start_time
- # logger.info(f"------------table recognition processing ends within {run_time}s-----")
if run_time > self.table_max_time:
- logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
+ logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
# 判断是否返回正常
-
- if latex_code:
- expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
- if expected_ending:
- res["latex"] = latex_code
- else:
- logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
- elif html_code:
+ if html_code:
expected_ending = html_code.strip().endswith('