Commit 79c8a5c8 authored by myhloli's avatar myhloli
Browse files

feat(table): upgrade RapidTable to1.0.3 and add sub-model support

- Update RapidTable dependency to version 1.0.3
- Add support for sub-models in RapidTable
- Update magic-pdf configuration to include table sub-model
- Modify table model initialization to support sub-models
- Update table prediction logic to handle new output format
parent 46ce94eb
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
}, },
"table-config": { "table-config": {
"model": "rapid_table", "model": "rapid_table",
"sub_model": "slanet_plus",
"enable": true, "enable": true,
"max_time": 400 "max_time": 400
}, },
...@@ -39,5 +40,5 @@ ...@@ -39,5 +40,5 @@
"enable": false "enable": false
} }
}, },
"config_version": "1.1.0" "config_version": "1.1.1"
} }
\ No newline at end of file
...@@ -69,6 +69,7 @@ class CustomPEKModel: ...@@ -69,6 +69,7 @@ class CustomPEKModel:
self.apply_table = self.table_config.get('enable', False) self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE) self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE) self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
self.table_sub_model_name = self.table_config.get('sub_model', None)
# ocr config # ocr config
self.apply_ocr = ocr self.apply_ocr = ocr
...@@ -174,6 +175,7 @@ class CustomPEKModel: ...@@ -174,6 +175,7 @@ class CustomPEKModel:
table_max_time=self.table_max_time, table_max_time=self.table_max_time,
device=self.device, device=self.device,
ocr_engine=self.ocr_model, ocr_engine=self.ocr_model,
table_sub_model_name=self.table_sub_model_name
) )
logger.info('DocAnalysis init done!') logger.info('DocAnalysis init done!')
...@@ -276,7 +278,7 @@ class CustomPEKModel: ...@@ -276,7 +278,7 @@ class CustomPEKModel:
elif self.table_model_name == MODEL_NAME.TABLE_MASTER: elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image) html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE: elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict( html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
new_image new_image
) )
run_time = time.time() - single_table_start_time run_time = time.time() - single_table_start_time
......
...@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \ ...@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None): def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE: if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time) table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER: elif table_model_type == MODEL_NAME.TABLE_MASTER:
...@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr ...@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
} }
table_model = TableMasterPaddleModel(config) table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE: elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel(ocr_engine) table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else: else:
logger.error('table model type not allow') logger.error('table model type not allow')
exit(1) exit(1)
...@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'), kwargs.get('table_model_path'),
kwargs.get('table_max_time'), kwargs.get('table_max_time'),
kwargs.get('device'), kwargs.get('device'),
kwargs.get('ocr_engine') kwargs.get('ocr_engine'),
kwargs.get('table_sub_model_name')
) )
elif model_name == AtomicModel.LangDetect: elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect: if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
......
...@@ -2,12 +2,25 @@ import cv2 ...@@ -2,12 +2,25 @@ import cv2
import numpy as np import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from rapid_table import RapidTable from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType
class RapidTableModel(object): class RapidTableModel(object):
def __init__(self, ocr_engine): def __init__(self, ocr_engine, table_sub_model_name):
self.table_model = RapidTable() sub_model_list = [model.value for model in ModelType]
if table_sub_model_name is None:
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
else:
input_args = RapidTableInput(model_type=table_sub_model_name)
else:
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
self.table_model = RapidTable(input_args)
# if ocr_engine is None: # if ocr_engine is None:
# self.ocr_model_name = "RapidOCR" # self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available(): # if torch.cuda.is_available():
...@@ -45,7 +58,11 @@ class RapidTableModel(object): ...@@ -45,7 +58,11 @@ class RapidTableModel(object):
ocr_result = None ocr_result = None
if ocr_result: if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result) table_results = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse html_code = table_results.pred_html
table_cell_bboxes = table_results.cell_bboxes
logic_points = table_results.logic_points
elapse = table_results.elapse
return html_code, table_cell_bboxes, logic_points, elapse
else: else:
return None, None, None return None, None, None, None
...@@ -51,7 +51,7 @@ if __name__ == '__main__': ...@@ -51,7 +51,7 @@ if __name__ == '__main__':
"doclayout_yolo==0.0.2b1", # doclayout_yolo "doclayout_yolo==0.0.2b1", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle "rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime", "rapidocr_onnxruntime",
"rapid_table==0.3.0", # rapid_table "rapid_table>=1.0.3,<2.0.0", # rapid_table
"PyYAML", # yaml "PyYAML", # yaml
"openai", # openai SDK "openai", # openai SDK
"detectron2" "detectron2"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment