Commit 512adb67 authored by myhloli's avatar myhloli
Browse files

feat(model): add onnxruntime support for paddleocr on cpu

- Implement ONNXModelSingleton to manage ONNX models
- Modify ModifiedPaddleOCR to use ONNX models on ARM CPUs without CUDA
- Update RapidTableModel to use RapidOCR with ONNXRuntime on CPU
- Add rapidocr_onnxruntime dependency in setup.py
parent 7c5cdcd4
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
"layoutreader-model-dir":"/tmp/layoutreader", "layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu", "device-mode":"cpu",
"layout-config": { "layout-config": {
"model": "layoutlmv3" "model": "doclayout_yolo"
}, },
"formula-config": { "formula-config": {
"mfd_model": "yolo_v8_mfd", "mfd_model": "yolo_v8_mfd",
......
...@@ -70,11 +70,6 @@ def ocr_model_init(show_log: bool = False, ...@@ -70,11 +70,6 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=1.8, det_db_unclip_ratio=1.8,
): ):
# use_npu = False
# device = get_device()
# if str(device).startswith("npu"):
# use_npu = True
if lang is not None and lang != '': if lang is not None and lang != '':
model = ModifiedPaddleOCR( model = ModifiedPaddleOCR(
show_log=show_log, show_log=show_log,
...@@ -82,7 +77,6 @@ def ocr_model_init(show_log: bool = False, ...@@ -82,7 +77,6 @@ def ocr_model_init(show_log: bool = False,
lang=lang, lang=lang,
use_dilation=use_dilation, use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio, det_db_unclip_ratio=det_db_unclip_ratio,
# use_npu=use_npu,
) )
else: else:
model = ModifiedPaddleOCR( model = ModifiedPaddleOCR(
...@@ -90,7 +84,6 @@ def ocr_model_init(show_log: bool = False, ...@@ -90,7 +84,6 @@ def ocr_model_init(show_log: bool = False,
det_db_box_thresh=det_db_box_thresh, det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation, use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio, det_db_unclip_ratio=det_db_unclip_ratio,
# use_npu=use_npu,
) )
return model return model
...@@ -160,6 +153,7 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -160,6 +153,7 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'), kwargs.get('table_model_path'),
kwargs.get('table_max_time'), kwargs.get('table_max_time'),
kwargs.get('device'), kwargs.get('device'),
kwargs.get('ocr_engine')
) )
else: else:
logger.error('model name not allow') logger.error('model name not allow')
......
...@@ -303,4 +303,54 @@ def calculate_is_angle(poly): ...@@ -303,4 +303,54 @@ def calculate_is_angle(poly):
return False return False
else: else:
# logger.info((p3[1] - p1[1])/height) # logger.info((p3[1] - p1[1])/height)
return True return True
\ No newline at end of file
class ONNXModelSingleton:
    """Singleton that hands out ONNX-backed OCR models, one per parameter set.

    Every instantiation returns the same object, and models are stored in a
    class-level dict so each distinct configuration is initialised only once.
    """

    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        """Return the single shared instance, creating it on first use."""
        shared = cls._instance
        if shared is None:
            shared = super().__new__(cls)
            cls._instance = shared
        return shared

    def get_onnx_model(self, **kwargs):
        """Return the cached model for the given settings, building it if absent.

        Recognised keyword arguments (others are ignored):
            lang: defaults to None.
            det_db_box_thresh: defaults to 0.3.
            use_dilation: defaults to True.
            det_db_unclip_ratio: defaults to 1.8.
        """
        cache_key = (
            kwargs.get('lang', None),
            kwargs.get('det_db_box_thresh', 0.3),
            kwargs.get('use_dilation', True),
            kwargs.get('det_db_unclip_ratio', 1.8),
        )
        if cache_key in self._models:
            return self._models[cache_key]
        # First request for this configuration: build and remember the model.
        self._models[cache_key] = onnx_model_init(cache_key)
        return self._models[cache_key]
def onnx_model_init(key):
    """Create a PaddleOCR instance that runs the ONNX models bundled with rapidocr_onnxruntime.

    Args:
        key: Indexable of at least four items:
            ``(lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)``.
            ``lang`` may be None, in which case no ``lang`` argument is
            forwarded to PaddleOCR.

    Returns:
        The constructed ``PaddleOCR`` object.

    Raises:
        Whatever ``PaddleOCR(...)`` or the package lookup raises; failures are
        propagated instead of being turned into a process exit.
    """
    import importlib.resources

    from paddleocr import PaddleOCR

    # importlib.resources.path() returns a context manager, not a path.
    # Formatting it directly into a string produces the manager's repr
    # ("<contextlib._GeneratorContextManager ...>"), so the real directory
    # must be obtained by entering the context.
    with importlib.resources.path('rapidocr_onnxruntime.models', '') as resource_path:
        additional_ocr_params = {
            "use_onnx": True,
            "det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
            "rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
            "cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
            "det_db_box_thresh": key[1],
            "use_dilation": key[2],
            "det_db_unclip_ratio": key[3],
        }
        if key[0] is not None:
            additional_ocr_params["lang"] = key[0]

        # Log after "lang" has been added so the message shows the final arguments.
        logger.info(f"additional_ocr_params: {additional_ocr_params}")

        # Build the model while the resource context is still open so the
        # model files are guaranteed to be present during loading.
        # A constructor call never yields None, so no post-construction
        # None check is needed; errors surface as exceptions.
        return PaddleOCR(**additional_ocr_params)
\ No newline at end of file
import copy import copy
import platform
import time import time
import cv2 import cv2
import numpy as np import numpy as np
import torch
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
...@@ -9,12 +11,23 @@ from ppocr.utils.utility import alpha_to_color, binarize_img ...@@ -9,12 +11,23 @@ from ppocr.utils.utility import alpha_to_color, binarize_img
from tools.infer.predict_system import sorted_boxes from tools.infer.predict_system import sorted_boxes
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
ONNXModelSingleton
logger = get_logger() logger = get_logger()
class ModifiedPaddleOCR(PaddleOCR): class ModifiedPaddleOCR(PaddleOCR):
def __init__(self, *args, **kwargs):
    """Initialise PaddleOCR and, on ARM CPUs without CUDA, attach an ONNX-backed model.

    All positional and keyword arguments are forwarded to ``PaddleOCR``.
    The keyword arguments are also passed to
    ``ONNXModelSingleton.get_onnx_model`` when the ONNX path is selected.

    Sets:
        use_onnx: True when the ONNX model should be used, False otherwise.
        additional_ocr: the ONNX-backed model; only set when ``use_onnx`` is True.
    """
    super().__init__(*args, **kwargs)
    # Use ONNX when the CPU architecture is ARM and CUDA is not available.
    if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
        self.use_onnx = True
        onnx_model_manager = ONNXModelSingleton()
        self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
    else:
        # The detection/recognition code paths in this class branch on
        # self.use_onnx unconditionally, so it must always be defined;
        # leaving it unset raises AttributeError on non-ARM or CUDA hosts.
        self.use_onnx = False
def ocr(self, def ocr(self,
img, img,
det=True, det=True,
...@@ -79,7 +92,10 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -79,7 +92,10 @@ class ModifiedPaddleOCR(PaddleOCR):
ocr_res = [] ocr_res = []
for img in imgs: for img in imgs:
img = preprocess_image(img) img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img) if self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
if dt_boxes is None: if dt_boxes is None:
ocr_res.append(None) ocr_res.append(None)
continue continue
...@@ -106,7 +122,10 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -106,7 +122,10 @@ class ModifiedPaddleOCR(PaddleOCR):
img, cls_res_tmp, elapse = self.text_classifier(img) img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec: if not rec:
cls_res.append(cls_res_tmp) cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(img) if self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img)
else:
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res) ocr_res.append(rec_res)
if not rec: if not rec:
return cls_res return cls_res
...@@ -121,7 +140,10 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -121,7 +140,10 @@ class ModifiedPaddleOCR(PaddleOCR):
start = time.time() start = time.time()
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) if self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse time_dict['det'] = elapse
if dt_boxes is None: if dt_boxes is None:
...@@ -159,8 +181,10 @@ class ModifiedPaddleOCR(PaddleOCR): ...@@ -159,8 +181,10 @@ class ModifiedPaddleOCR(PaddleOCR):
time_dict['cls'] = elapse time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format( logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse)) len(img_crop_list), elapse))
if self.use_onnx:
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else:
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict['rec'] = elapse time_dict['rec'] = elapse
logger.debug("rec_res num : {}, elapsed : {}".format( logger.debug("rec_res num : {}, elapsed : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
......
import cv2 import cv2
import numpy as np import numpy as np
import torch
from loguru import logger from loguru import logger
from rapid_table import RapidTable from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR
class RapidTableModel(object): class RapidTableModel(object):
...@@ -10,7 +10,12 @@ class RapidTableModel(object): ...@@ -10,7 +10,12 @@ class RapidTableModel(object):
self.table_model = RapidTable() self.table_model = RapidTable()
if ocr_engine is None: if ocr_engine is None:
self.ocr_model_name = "RapidOCR" self.ocr_model_name = "RapidOCR"
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True) if torch.cuda.is_available():
from rapidocr_paddle import RapidOCR
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
from rapidocr_onnxruntime import RapidOCR
self.ocr_engine = RapidOCR()
else: else:
self.ocr_model_name = "PaddleOCR" self.ocr_model_name = "PaddleOCR"
self.ocr_engine = ocr_engine self.ocr_engine = ocr_engine
......
...@@ -5,6 +5,7 @@ from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text ...@@ -5,6 +5,7 @@ from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
from openai import OpenAI from openai import OpenAI
#@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容: formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
1. 修正渲染或编译错误: 1. 修正渲染或编译错误:
......
...@@ -50,6 +50,7 @@ if __name__ == '__main__': ...@@ -50,6 +50,7 @@ if __name__ == '__main__':
"accelerate", # struct-eqtable依赖 "accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo "doclayout_yolo==0.0.2", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle "rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime",
"rapid_table", # rapid_table "rapid_table", # rapid_table
"PyYAML", # yaml "PyYAML", # yaml
"openai", # openai SDK "openai", # openai SDK
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment