Commit 5252c46e authored by myhloli
Browse files

refactor(ocr): comment out print statements and update table model initialization

- Comment out print statements in base_ocr_v20.py and pytorch_paddle.py
- Update table model initialization to use lang parameter instead of ocr_engine
- Remove unused RapidOCR initialization in rapid_table.py
parent 9b3339f1
......@@ -36,7 +36,7 @@ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableM
# from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
......@@ -48,6 +48,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
}
table_model = TableMasterPaddleModel(config)
elif table_model_type == MODEL_NAME.RAPID_TABLE:
atom_model_manager = AtomModelSingleton()
ocr_engine = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
lang=lang
)
table_model = RapidTableModel(ocr_engine, table_sub_model_name)
else:
logger.error('table model type not allow')
......@@ -134,7 +142,7 @@ class AtomModelSingleton:
elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name)
key = (atom_model_name, table_model_name, lang)
else:
key = atom_model_name
......@@ -182,7 +190,7 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('ocr_engine'),
kwargs.get('lang'),
kwargs.get('table_sub_model_name')
)
elif model_name == AtomicModel.LangDetect:
......
......@@ -109,7 +109,7 @@ class PytorchPaddleOCR(TextSystem):
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
if dt_boxes is None:
ocr_res.append(None)
continue
......@@ -128,7 +128,7 @@ class PytorchPaddleOCR(TextSystem):
img = preprocess_image(img)
img = [img]
rec_res, elapse = self.text_recognizer(img)
logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
ocr_res.append(rec_res)
return ocr_res
......@@ -146,7 +146,7 @@ class PytorchPaddleOCR(TextSystem):
return None, None
else:
pass
logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
# logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
img_crop_list = []
dt_boxes = sorted_boxes(dt_boxes)
......@@ -163,7 +163,7 @@ class PytorchPaddleOCR(TextSystem):
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
# logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
filter_boxes, filter_rec_res = [], []
for box, rec_result in zip(dt_boxes, rec_res):
......
......@@ -27,11 +27,11 @@ class BaseOCRV20:
def load_state_dict(self, weights):
self.net.load_state_dict(weights)
print('weights is loaded.')
# print('weights is loaded.')
def load_pytorch_weights(self, weights_path):
self.net.load_state_dict(torch.load(weights_path, weights_only=True))
print('model is loaded: {}'.format(weights_path))
# print('model is loaded: {}'.format(weights_path))
def inference(self, inputs):
with torch.no_grad():
......
......@@ -23,25 +23,17 @@ class RapidTableModel(object):
self.table_model = RapidTable(input_args)
# if ocr_engine is None:
# self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available():
# from rapidocr_paddle import RapidOCR
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
# else:
# from rapidocr_onnxruntime import RapidOCR
# self.ocr_engine = RapidOCR()
# self.ocr_model_name = "RapidOCR"
# if torch.cuda.is_available():
# from rapidocr_paddle import RapidOCR
# self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
# else:
# self.ocr_model_name = "PaddleOCR"
# self.ocr_engine = ocr_engine
# from rapidocr_onnxruntime import RapidOCR
# self.ocr_engine = RapidOCR()
self.ocr_model_name = "PaddleOCR"
self.ocr_engine = ocr_engine
self.ocr_model_name = "RapidOCR"
if torch.cuda.is_available():
from rapidocr_paddle import RapidOCR
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
from rapidocr_onnxruntime import RapidOCR
self.ocr_engine = RapidOCR()
def predict(self, image):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment