Commit f2a92d57 authored by myhloli's avatar myhloli
Browse files

refactor(model): implement thread-safe OCR model initialization

- Add threading support for OCR model initialization
- Modify AtomModelSingleton to handle thread-specific instances
- Update PDFExtractKit and PDFParseUnionCoreV2 to use new thread-safe OCR initialization
parent ec5a09db
......@@ -22,7 +22,7 @@ except ImportError:
from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
......@@ -150,8 +150,9 @@ class CustomPEKModel:
device=self.device,
)
# 初始化ocr
self.ocr_model = ocr_model_init(
show_log=show_log,
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
......
......@@ -57,6 +57,11 @@ def doclayout_yolo_model_init(weight, device='cpu'):
return model
import threading
current_thread = threading.current_thread()
current_thread_id = current_thread.ident
def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3,
lang=None,
......@@ -92,14 +97,24 @@ class AtomModelSingleton:
return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang)
table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang, current_thread_id)
elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name)
else:
key = atom_model_name
if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key]
def atom_model_init(model_name: str, **kwargs):
atom_model = None
if model_name == AtomicModel.Layout:
......@@ -129,7 +144,7 @@ def atom_model_init(model_name: str, **kwargs):
atom_model = ocr_model_init(
kwargs.get('ocr_show_log'),
kwargs.get('det_db_box_thresh'),
kwargs.get('lang')
kwargs.get('lang'),
)
elif model_name == AtomicModel.Table:
atom_model = table_model_init(
......
......@@ -31,7 +31,7 @@ try:
except ImportError:
pass
from magic_pdf.model.sub_modules.model_init import ocr_model_init
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
......@@ -231,9 +231,10 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
if len(empty_spans) > 0:
# 初始化ocr模型
ocr_model = ocr_model_init(
show_log=False,
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment