Commit 512adb67 authored by myhloli's avatar myhloli
Browse files

feat(model): add onnxruntime support for paddleocr on cpu

- Implement ONNXModelSingleton to manage ONNX models
- Modify ModifiedPaddleOCR to use ONNX models on ARM CPUs without CUDA
- Update RapidTableModel to use RapidOCR with ONNXRuntime on CPU
- Add rapidocr_onnxruntime dependency in setup.py
parent 7c5cdcd4
......@@ -7,7 +7,7 @@
"layoutreader-model-dir":"/tmp/layoutreader",
"device-mode":"cpu",
"layout-config": {
"model": "layoutlmv3"
"model": "doclayout_yolo"
},
"formula-config": {
"mfd_model": "yolo_v8_mfd",
......
......@@ -70,11 +70,6 @@ def ocr_model_init(show_log: bool = False,
det_db_unclip_ratio=1.8,
):
# use_npu = False
# device = get_device()
# if str(device).startswith("npu"):
# use_npu = True
if lang is not None and lang != '':
model = ModifiedPaddleOCR(
show_log=show_log,
......@@ -82,7 +77,6 @@ def ocr_model_init(show_log: bool = False,
lang=lang,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
# use_npu=use_npu,
)
else:
model = ModifiedPaddleOCR(
......@@ -90,7 +84,6 @@ def ocr_model_init(show_log: bool = False,
det_db_box_thresh=det_db_box_thresh,
use_dilation=use_dilation,
det_db_unclip_ratio=det_db_unclip_ratio,
# use_npu=use_npu,
)
return model
......@@ -160,6 +153,7 @@ def atom_model_init(model_name: str, **kwargs):
kwargs.get('table_model_path'),
kwargs.get('table_max_time'),
kwargs.get('device'),
kwargs.get('ocr_engine')
)
else:
logger.error('model name not allow')
......
......@@ -303,4 +303,54 @@ def calculate_is_angle(poly):
return False
else:
# logger.info((p3[1] - p1[1])/height)
return True
\ No newline at end of file
return True
class ONNXModelSingleton:
_instance = None
_models = {}
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def get_onnx_model(self, **kwargs):
lang = kwargs.get('lang', None)
det_db_box_thresh = kwargs.get('det_db_box_thresh', 0.3)
use_dilation = kwargs.get('use_dilation', True)
det_db_unclip_ratio = kwargs.get('det_db_unclip_ratio', 1.8)
key = (lang, det_db_box_thresh, use_dilation, det_db_unclip_ratio)
if key not in self._models:
self._models[key] = onnx_model_init(key)
return self._models[key]
def onnx_model_init(key):
import importlib.resources
resource_path = importlib.resources.path('rapidocr_onnxruntime.models','')
onnx_model = None
additional_ocr_params = {
"use_onnx": True,
"det_model_dir": f'{resource_path}/ch_PP-OCRv4_det_infer.onnx',
"rec_model_dir": f'{resource_path}/ch_PP-OCRv4_rec_infer.onnx',
"cls_model_dir": f'{resource_path}/ch_ppocr_mobile_v2.0_cls_infer.onnx',
"det_db_box_thresh": key[1],
"use_dilation": key[2],
"det_db_unclip_ratio": key[3],
}
logger.info(f"additional_ocr_params: {additional_ocr_params}")
if key[0] is not None:
additional_ocr_params["lang"] = key[0]
from paddleocr import PaddleOCR
onnx_model = PaddleOCR(**additional_ocr_params)
if onnx_model is None:
logger.error('model init failed')
exit(1)
else:
return onnx_model
\ No newline at end of file
import copy
import platform
import time
import cv2
import numpy as np
import torch
from paddleocr import PaddleOCR
from ppocr.utils.logging import get_logger
......@@ -9,12 +11,23 @@ from ppocr.utils.utility import alpha_to_color, binarize_img
from tools.infer.predict_system import sorted_boxes
from tools.infer.utility import get_rotate_crop_image, get_minarea_rect_crop
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import update_det_boxes, merge_det_boxes, check_img, \
ONNXModelSingleton
logger = get_logger()
class ModifiedPaddleOCR(PaddleOCR):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# 在cpu架构为arm且不支持cuda时调用onnx、
if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
self.use_onnx = True
onnx_model_manager = ONNXModelSingleton()
self.additional_ocr = onnx_model_manager.get_onnx_model(**kwargs)
def ocr(self,
img,
det=True,
......@@ -79,7 +92,10 @@ class ModifiedPaddleOCR(PaddleOCR):
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
if dt_boxes is None:
ocr_res.append(None)
continue
......@@ -106,7 +122,10 @@ class ModifiedPaddleOCR(PaddleOCR):
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(img)
if self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img)
else:
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res)
if not rec:
return cls_res
......@@ -121,7 +140,10 @@ class ModifiedPaddleOCR(PaddleOCR):
start = time.time()
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
if self.use_onnx:
dt_boxes, elapse = self.additional_ocr.text_detector(img)
else:
dt_boxes, elapse = self.text_detector(img)
time_dict['det'] = elapse
if dt_boxes is None:
......@@ -159,8 +181,10 @@ class ModifiedPaddleOCR(PaddleOCR):
time_dict['cls'] = elapse
logger.debug("cls num : {}, elapsed : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
if self.use_onnx:
rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
else:
rec_res, elapse = self.text_recognizer(img_crop_list)
time_dict['rec'] = elapse
logger.debug("rec_res num : {}, elapsed : {}".format(
len(rec_res), elapse))
......
import cv2
import numpy as np
import torch
from loguru import logger
from rapid_table import RapidTable
from rapidocr_paddle import RapidOCR
class RapidTableModel(object):
......@@ -10,7 +10,12 @@ class RapidTableModel(object):
self.table_model = RapidTable()
if ocr_engine is None:
self.ocr_model_name = "RapidOCR"
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
if torch.cuda.is_available():
from rapidocr_paddle import RapidOCR
self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
else:
from rapidocr_onnxruntime import RapidOCR
self.ocr_engine = RapidOCR()
else:
self.ocr_model_name = "PaddleOCR"
self.ocr_engine = ocr_engine
......
......@@ -5,6 +5,7 @@ from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
from openai import OpenAI
#@todo: 有的公式以"\"结尾,这样会导致尾部拼接的"$"被转义,也需要修复
formula_optimize_prompt = """请根据以下指南修正LaTeX公式的错误,确保公式能够渲染且符合原始内容:
1. 修正渲染或编译错误:
......
......@@ -50,6 +50,7 @@ if __name__ == '__main__':
"accelerate", # struct-eqtable依赖
"doclayout_yolo==0.0.2", # doclayout_yolo
"rapidocr-paddle", # rapidocr-paddle
"rapidocr_onnxruntime",
"rapid_table", # rapid_table
"PyYAML", # yaml
"openai", # openai SDK
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment