"examples/vscode:/vscode.git/clone" did not exist on "96a5e4dd795b675210b0d18f5e9fab69ec69bb6e"
Commit 5252c46e authored by myhloli

refactor(ocr): comment out print statements and update table model initialization

- Comment out print statements in base_ocr_v20.py and pytorch_paddle.py
- Update table model initialization to use lang parameter instead of ocr_engine
- Remove unused RapidOCR initialization in rapid_table.py
parent 9b3339f1
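
For orientation, a hedged usage sketch of the new `table_model_init` signature introduced in the first hunk below. It assumes the surrounding module's names (MODEL_NAME etc.) are in scope; the weights path and timeout are placeholders, not values from this repository:

# Before this commit a caller had to hand in a ready OCR engine:
#   table_model_init(MODEL_NAME.RAPID_TABLE, path, max_time, _device_='cpu',
#                    ocr_engine=my_engine, table_sub_model_name=None)
# After it, the engine is built internally from the requested language:
table_model = table_model_init(
    MODEL_NAME.RAPID_TABLE,        # table_model_type
    '/path/to/table/weights',      # model_path (placeholder)
    400,                           # max_time (illustrative)
    _device_='cpu',
    lang='ch',                     # new parameter replacing ocr_engine
    table_sub_model_name=None,
)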
@@ -36,7 +36,7 @@ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableM
 # from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
@@ -48,6 +48,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
+        atom_model_manager = AtomModelSingleton()
+        ocr_engine = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.5,
+            det_db_unclip_ratio=1.6,
+            lang=lang
+        )
         table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
@@ -134,7 +142,7 @@ class AtomModelSingleton:
         elif atom_model_name in [AtomicModel.Layout]:
             key = (atom_model_name, layout_model_name)
         elif atom_model_name in [AtomicModel.Table]:
-            key = (atom_model_name, table_model_name)
+            key = (atom_model_name, table_model_name, lang)
         else:
             key = atom_model_name
@@ -182,7 +190,7 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('ocr_engine'),
+            kwargs.get('lang'),
             kwargs.get('table_sub_model_name')
         )
     elif model_name == AtomicModel.LangDetect:
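
A consequence of the cache-key hunk above (@@ -134,7 +142,7), sketched under the assumption that AtomModelSingleton.get_atom_model caches one instance per key and forwards its keyword arguments to atom_model_init; the kwarg names mirror the kwargs.get(...) calls in the hunks, and the path and timeout are placeholders:

manager = AtomModelSingleton()
common = dict(
    atom_model_name=AtomicModel.Table,
    table_model_name=MODEL_NAME.RAPID_TABLE,
    table_model_path='/path/to/table/weights',   # placeholder
    table_max_time=400,                          # illustrative
    device='cpu',
    table_sub_model_name=None,
)
table_ch = manager.get_atom_model(lang='ch', **common)
table_en = manager.get_atom_model(lang='en', **common)
# With lang in the key (atom_model_name, table_model_name, lang), the two calls
# build and cache separate table models; before this commit they shared one entry.
assert table_ch is not table_en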
pytorch_paddle.py
@@ -109,7 +109,7 @@ class PytorchPaddleOCR(TextSystem):
             for img in imgs:
                 img = preprocess_image(img)
                 dt_boxes, elapse = self.text_detector(img)
-                logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
+                # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
                 if dt_boxes is None:
                     ocr_res.append(None)
                     continue
@@ -128,7 +128,7 @@ class PytorchPaddleOCR(TextSystem):
                 img = preprocess_image(img)
                 img = [img]
                 rec_res, elapse = self.text_recognizer(img)
-                logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
+                # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
                 ocr_res.append(rec_res)
             return ocr_res
@@ -146,7 +146,7 @@ class PytorchPaddleOCR(TextSystem):
             return None, None
         else:
             pass
-        logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
+        # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
         img_crop_list = []
         dt_boxes = sorted_boxes(dt_boxes)
@@ -163,7 +163,7 @@ class PytorchPaddleOCR(TextSystem):
             img_crop_list.append(img_crop)
         rec_res, elapse = self.text_recognizer(img_crop_list)
-        logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
+        # logger.debug("rec_res num : {}, elapsed : {}".format(len(rec_res), elapse))
         filter_boxes, filter_rec_res = [], []
         for box, rec_result in zip(dt_boxes, rec_res):
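
The lines commented out above go through `logger`; assuming that is loguru's logger (magic_pdf uses loguru elsewhere), the same output can also be silenced by sink level. A minimal standalone reference sketch, with illustrative values only:

import sys
from loguru import logger

logger.remove()                       # drop loguru's default DEBUG-level stderr sink
logger.add(sys.stderr, level="INFO")  # keep INFO and above

logger.debug("dt_boxes num : {}, elapsed : {}", 12, 0.034)  # filtered out
logger.info("ocr step finished")                            # still printed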
base_ocr_v20.py
@@ -27,11 +27,11 @@ class BaseOCRV20:
     def load_state_dict(self, weights):
         self.net.load_state_dict(weights)
-        print('weights is loaded.')
+        # print('weights is loaded.')

     def load_pytorch_weights(self, weights_path):
         self.net.load_state_dict(torch.load(weights_path, weights_only=True))
-        print('model is loaded: {}'.format(weights_path))
+        # print('model is loaded: {}'.format(weights_path))

     def inference(self, inputs):
         with torch.no_grad():
rapid_table.py
@@ -23,7 +23,6 @@ class RapidTableModel(object):
         self.table_model = RapidTable(input_args)
-        # if ocr_engine is None:
         # self.ocr_model_name = "RapidOCR"
         # if torch.cuda.is_available():
         # from rapidocr_paddle import RapidOCR
@@ -31,17 +30,10 @@ class RapidTableModel(object):
         # else:
         # from rapidocr_onnxruntime import RapidOCR
         # self.ocr_engine = RapidOCR()
-        # else:
-        # self.ocr_model_name = "PaddleOCR"
-        # self.ocr_engine = ocr_engine
-        self.ocr_model_name = "RapidOCR"
-        if torch.cuda.is_available():
-            from rapidocr_paddle import RapidOCR
-            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-        else:
-            from rapidocr_onnxruntime import RapidOCR
-            self.ocr_engine = RapidOCR()
+        self.ocr_model_name = "PaddleOCR"
+        self.ocr_engine = ocr_engine

     def predict(self, image):