Commit 878f3de0 authored by 赵小蒙
Browse files

refactor(magic_pdf): optimize model initialization and threading

- Remove unnecessary threading.Lock in AtomModelSingleton
- Add threading.Lock to CustomPEKModel for OCR processing
- Simplify model initialization logic in AtomModelSingleton
parent 7ca7e599
...@@ -28,6 +28,8 @@ from magic_pdf.model.sub_modules.model_utils import ( ...@@ -28,6 +28,8 @@ from magic_pdf.model.sub_modules.model_utils import (
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list) get_adjusted_mfdetrec_res, get_ocr_result_list)
from threading import Lock
class CustomPEKModel: class CustomPEKModel:
...@@ -209,16 +211,18 @@ class CustomPEKModel: ...@@ -209,16 +211,18 @@ class CustomPEKModel:
# ocr识别 # ocr识别
ocr_start = time.time() ocr_start = time.time()
# Process each area that requires OCR processing # Process each area that requires OCR processing
lock = Lock()
for res in ocr_res_list: for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50) new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list) adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
if self.apply_ocr: with lock:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] if self.apply_ocr:
else: ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0] else:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
# Integration results # Integration results
if ocr_res: if ocr_res:
......
...@@ -82,12 +82,9 @@ def ocr_model_init(show_log: bool = False, ...@@ -82,12 +82,9 @@ def ocr_model_init(show_log: bool = False,
return model return model
from threading import Lock
class AtomModelSingleton: class AtomModelSingleton:
_instance = None _instance = None
_models = {} _models = {}
_lock = Lock()
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
if cls._instance is None: if cls._instance is None:
...@@ -98,17 +95,10 @@ class AtomModelSingleton: ...@@ -98,17 +95,10 @@ class AtomModelSingleton:
lang = kwargs.get('lang', None) lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None) layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang) key = (atom_model_name, layout_model_name, lang)
if atom_model_name == AtomicModel.OCR: if key not in self._models:
with self._lock: self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
else:
return self._models[key]
else: else:
if key not in self._models: return self._models[key]
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
else:
return self._models[key]
def atom_model_init(model_name: str, **kwargs): def atom_model_init(model_name: str, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment