Unverified Commit 5a3872b2 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #915 from myhloli/dev

feat(table): add RapidOCR support for RapidTable model
parents 5e0c9d2e fe2c2c0d
......@@ -26,6 +26,7 @@ try:
from unimernet.processors import load_processor
from doclayout_yolo import YOLOv10
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR
except ImportError as e:
logger.exception(e)
......@@ -42,6 +43,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
ocr_engine = None
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
table_model = StructTableModel(model_path, max_time=max_time)
elif table_model_type == MODEL_NAME.TABLE_MASTER:
......@@ -52,11 +54,15 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
table_model = ppTableModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
table_model = RapidTable()
ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
logger.error("table model type not allow")
exit(1)
return table_model
if ocr_engine:
return [table_model, ocr_engine]
else:
return table_model
def mfd_model_init(weight):
......@@ -283,23 +289,32 @@ class CustomPEKModel:
doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
)
# 初始化ocr
# if self.apply_ocr:
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
if self.apply_ocr:
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
# init table model
if self.apply_table:
table_model_dir = self.configs["weights"][self.table_model_name]
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]:
self.table_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]:
self.table_model, self.ocr_engine =atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.Table,
table_model_name=self.table_model_name,
table_model_path=str(os.path.join(models_dir, table_model_dir)),
table_max_time=self.table_max_time,
device=self.device
)
logger.info('DocAnalysis init done!')
......@@ -381,9 +396,8 @@ class CustomPEKModel:
table_res_list.append(res)
if torch.cuda.is_available() and self.device != 'cpu':
properties = torch.cuda.get_device_properties(self.device)
total_memory = properties.total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 10:
total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3) # 将字节转换为 GB
if total_memory <= 8:
gc_start = time.time()
clean_memory()
gc_time = round(time.time() - gc_start, 2)
......@@ -456,13 +470,8 @@ class CustomPEKModel:
elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
html_code = self.table_model.img2html(new_image)
elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
ocr_result = self.ocr_model.ocr(new_image_bgr)[0]
new_ocr_result = []
for box_ocr_res in ocr_result:
text, score = box_ocr_res[1]
new_ocr_result.append([box_ocr_res[0], text, score])
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result)
ocr_result, _ = self.ocr_engine(np.asarray(new_image))
html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result)
run_time = time.time() - single_table_start_time
# logger.info(f"------------table recognition processing ends within {run_time}s-----")
......
......@@ -47,6 +47,7 @@ if __name__ == '__main__':
"einops", # struct-eqtable依赖
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle
"rapid_table", # rapid_table
"detectron2"
],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment