Commit 878f3de0 authored by 赵小蒙
Browse files

refactor(magic_pdf): optimize model initialization and threading

- Remove unnecessary threading.Lock in AtomModelSingleton
- Add threading.Lock to CustomPEKModel for OCR processing
- Simplify model initialization logic in AtomModelSingleton
parent 7ca7e599
......@@ -28,6 +28,8 @@ from magic_pdf.model.sub_modules.model_utils import (
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
get_adjusted_mfdetrec_res, get_ocr_result_list)
from threading import Lock
class CustomPEKModel:
......@@ -209,16 +211,18 @@ class CustomPEKModel:
# ocr识别
ocr_start = time.time()
# Process each area that requires OCR processing
lock = Lock()
for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
else:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
with lock:
if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
else:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
# Integration results
if ocr_res:
......
......@@ -82,12 +82,9 @@ def ocr_model_init(show_log: bool = False,
return model
from threading import Lock
class AtomModelSingleton:
_instance = None
_models = {}
_lock = Lock()
def __new__(cls, *args, **kwargs):
if cls._instance is None:
......@@ -98,17 +95,10 @@ class AtomModelSingleton:
lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang)
if atom_model_name == AtomicModel.OCR:
with self._lock:
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
else:
return self._models[key]
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
else:
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
else:
return self._models[key]
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment