Unverified Commit 74fba476 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #910 from myhloli/dev

feat(table): integrate RapidTable model for table recognition
parents 9581fcda e78edb19
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"enable": true "enable": true
}, },
"table-config": { "table-config": {
"model": "tablemaster", "model": "rapid_table",
"enable": false, "enable": false,
"max_time": 400 "max_time": 400
}, },
......
...@@ -50,4 +50,6 @@ class MODEL_NAME: ...@@ -50,4 +50,6 @@ class MODEL_NAME:
YOLO_V8_MFD = "yolo_v8_mfd" YOLO_V8_MFD = "yolo_v8_mfd"
UniMerNet_v2_Small = "unimernet_small" UniMerNet_v2_Small = "unimernet_small"
\ No newline at end of file
RAPID_TABLE = "rapid_table"
\ No newline at end of file
...@@ -92,7 +92,7 @@ def get_table_recog_config(): ...@@ -92,7 +92,7 @@ def get_table_recog_config():
table_config = config.get('table-config') table_config = config.get('table-config')
if table_config is None: if table_config is None:
logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default") logger.warning(f"'table-config' not found in {CONFIG_FILE_NAME}, use 'False' as default")
return json.loads(f'{{"model": "{MODEL_NAME.TABLE_MASTER}","enable": false, "max_time": 400}}') return json.loads(f'{{"model": "{MODEL_NAME.RAPID_TABLE}","enable": false, "max_time": 400}}')
else: else:
return table_config return table_config
......
from loguru import logger from loguru import logger
import os import os
import time import time
from pathlib import Path
import shutil
from magic_pdf.libs.Constants import * from magic_pdf.libs.Constants import *
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
...@@ -27,6 +25,7 @@ try: ...@@ -27,6 +25,7 @@ try:
import unimernet.tasks as tasks import unimernet.tasks as tasks
from unimernet.processors import load_processor from unimernet.processors import load_processor
from doclayout_yolo import YOLOv10 from doclayout_yolo import YOLOv10
from rapid_table import RapidTable
except ImportError as e: except ImportError as e:
logger.exception(e) logger.exception(e)
...@@ -51,9 +50,12 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'): ...@@ -51,9 +50,12 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
"device": _device_ "device": _device_
} }
table_model = ppTableModel(config) table_model = ppTableModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTable()
else: else:
logger.error("table model type not allow") logger.error("table model type not allow")
exit(1) exit(1)
return table_model return table_model
...@@ -226,7 +228,7 @@ class CustomPEKModel: ...@@ -226,7 +228,7 @@ class CustomPEKModel:
self.table_config = kwargs.get("table_config") self.table_config = kwargs.get("table_config")
self.apply_table = self.table_config.get("enable", False) self.apply_table = self.table_config.get("enable", False)
self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE) self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get("model", MODEL_NAME.TABLE_MASTER) self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)
# ocr config # ocr config
self.apply_ocr = ocr self.apply_ocr = ocr
...@@ -281,13 +283,13 @@ class CustomPEKModel: ...@@ -281,13 +283,13 @@ class CustomPEKModel:
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])) doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
) )
# 初始化ocr # 初始化ocr
if self.apply_ocr: # if self.apply_ocr:
self.ocr_model = atom_model_manager.get_atom_model( self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR, atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log, ocr_show_log=show_log,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=self.lang lang=self.lang
) )
# init table model # init table model
if self.apply_table: if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_name] table_model_dir = self.configs["weights"][self.table_model_name]
...@@ -451,8 +453,16 @@ class CustomPEKModel: ...@@ -451,8 +453,16 @@ class CustomPEKModel:
table_result = self.table_model.predict(new_image, "html") table_result = self.table_model.predict(new_image, "html")
if len(table_result) > 0: if len(table_result) > 0:
html_code = table_result[0] html_code = table_result[0]
else: elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image) html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_model.ocr(new_image_bgr)[0]
new_ocr_result = []
for box_ocr_res in ocr_result:
text, score = box_ocr_res[1]
new_ocr_result.append([box_ocr_res[0], text, score])
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result)
run_time = time.time() - single_table_start_time run_time = time.time() - single_table_start_time
# logger.info(f"------------table recognition processing ends within {run_time}s-----") # logger.info(f"------------table recognition processing ends within {run_time}s-----")
......
...@@ -4,4 +4,5 @@ weights: ...@@ -4,4 +4,5 @@ weights:
yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
unimernet_small: MFR/unimernet_small unimernet_small: MFR/unimernet_small
struct_eqtable: TabRec/StructEqTable struct_eqtable: TabRec/StructEqTable
tablemaster: TabRec/TableMaster tablemaster: TabRec/TableMaster
\ No newline at end of file rapid_table: TabRec/RapidTable
\ No newline at end of file
...@@ -47,6 +47,7 @@ if __name__ == '__main__': ...@@ -47,6 +47,7 @@ if __name__ == '__main__':
"einops", # struct-eqtable依赖 "einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖 "accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo "doclayout_yolo==0.0.2", # doclayout_yolo
"rapid_table", # rapid_table
"detectron2" "detectron2"
], ],
}, },
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment