Unverified Commit 230191c7 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1556 from myhloli/dev

feat(table): upgrade RapidTable to1.0.3 and add sub-model support
parents 63c267fa 452a9c0b
......@@ -19,7 +19,7 @@ einops
accelerate
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
......@@ -18,7 +18,7 @@ einops
accelerate
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
......@@ -18,7 +18,7 @@ einops
accelerate
rapidocr-paddle
rapidocr-onnxruntime
rapid_table==0.3.0
rapid-table>=1.0.3,<2.0.0
doclayout-yolo==0.0.2b1
openai
detectron2
......@@ -16,6 +16,7 @@
},
"table-config": {
"model": "rapid_table",
"sub_model": "slanet_plus",
"enable": true,
"max_time": 400
},
......@@ -39,5 +40,5 @@
"enable": false
}
},
"config_version": "1.1.0"
"config_version": "1.1.1"
}
\ No newline at end of file
......@@ -161,7 +161,7 @@ class BatchAnalyze:
elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.model.table_model.img2html(new_image)
elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = (
html_code, table_cell_bboxes, logic_points, elapse = (
self.model.table_model.predict(new_image)
)
run_time = time.time() - single_table_start_time
......
......@@ -69,6 +69,7 @@ class CustomPEKModel:
self.apply_table = self.table_config.get('enable', False)
self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
self.table_sub_model_name = self.table_config.get('sub_model', None)
# ocr config
self.apply_ocr = ocr
......@@ -174,6 +175,7 @@ class CustomPEKModel:
table_max_time=self.table_max_time,
device=self.device,
ocr_engine=self.ocr_model,
table_sub_model_name=self.table_sub_model_name
)
logger.info('DocAnalysis init done!')
......@@ -276,7 +278,7 @@ class CustomPEKModel:
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
html_code, table_cell_bboxes, elapse = self.table_model.predict(
html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
new_image
)
run_time = time.time() - single_table_start_time
......
......@@ -21,7 +21,7 @@ from magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle import \
TableMasterPaddleModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None):
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
......@@ -31,7 +31,7 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTableModel(ocr_engine)
table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else:
logger.error('table model type not allow')
exit(1)
......@@ -163,7 +163,8 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('ocr_engine')
kwargs.get('ocr_engine'),
kwargs.get('table_sub_model_name')
)
elif model_name == AtomicModel.LangDetect:
if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
......
......@@ -2,12 +2,25 @@ import cv2
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable
from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import ModelType
class RapidTableModel(object):
def __init__(self, ocr_engine):
self.table_model = RapidTable()
def __init__(self, ocr_engine, table_sub_model_name):
sub_model_list = [model.value for model in ModelType]
if table_sub_model_name is None:
input_args = RapidTableInput()
elif table_sub_model_name in sub_model_list:
if torch.cuda.is_available() and table_sub_model_name == "unitable":
input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
else:
input_args = RapidTableInput(model_type=table_sub_model_name)
else:
raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
self.table_model = RapidTable(input_args)
# if ocr_engine is None:
# self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available():
......@@ -45,7 +58,11 @@ class RapidTableModel(object):
ocr_result = None
if ocr_result:
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
return html_code, table_cell_bboxes, elapse
table_results = self.table_model(np.asarray(image), ocr_result)
html_code = table_results.pred_html
table_cell_bboxes = table_results.cell_bboxes
logic_points = table_results.logic_points
elapse = table_results.elapse
return html_code, table_cell_bboxes, logic_points, elapse
else:
return None, None, None
return None, None, None, None
......@@ -51,7 +51,7 @@ if __name__ == '__main__':
"doclayout_yolo==0.0.2b1", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime",
"rapid_table==0.3.0", # rapid_table
"rapid_table>=1.0.3,<2.0.0", # rapid_table
"PyYAML", # yaml
"openai", # openai SDK
"detectron2"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment