Unverified Commit 0acfce29 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1214 from myhloli/dev

refactor(model): implement thread-safe OCR model initialization
parents ec5a09db f2a92d57
...@@ -22,7 +22,7 @@ except ImportError: ...@@ -22,7 +22,7 @@ except ImportError:
from magic_pdf.config.constants import * from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
...@@ -150,8 +150,9 @@ class CustomPEKModel: ...@@ -150,8 +150,9 @@ class CustomPEKModel:
device=self.device, device=self.device,
) )
# 初始化ocr # 初始化ocr
self.ocr_model = ocr_model_init( self.ocr_model = atom_model_manager.get_atom_model(
show_log=show_log, atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=self.lang lang=self.lang
) )
......
...@@ -57,6 +57,11 @@ def doclayout_yolo_model_init(weight, device='cpu'): ...@@ -57,6 +57,11 @@ def doclayout_yolo_model_init(weight, device='cpu'):
return model return model
import threading
current_thread = threading.current_thread()
current_thread_id = current_thread.ident
def ocr_model_init(show_log: bool = False, def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=None, lang=None,
...@@ -92,14 +97,24 @@ class AtomModelSingleton: ...@@ -92,14 +97,24 @@ class AtomModelSingleton:
return cls._instance return cls._instance
def get_atom_model(self, atom_model_name: str, **kwargs): def get_atom_model(self, atom_model_name: str, **kwargs):
lang = kwargs.get('lang', None) lang = kwargs.get('lang', None)
layout_model_name = kwargs.get('layout_model_name', None) layout_model_name = kwargs.get('layout_model_name', None)
key = (atom_model_name, layout_model_name, lang) table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang, current_thread_id)
elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]:
key = (atom_model_name, table_model_name)
else:
key = atom_model_name
if key not in self._models: if key not in self._models:
self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs) self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
return self._models[key] return self._models[key]
def atom_model_init(model_name: str, **kwargs): def atom_model_init(model_name: str, **kwargs):
atom_model = None atom_model = None
if model_name == AtomicModel.Layout: if model_name == AtomicModel.Layout:
...@@ -129,7 +144,7 @@ def atom_model_init(model_name: str, **kwargs): ...@@ -129,7 +144,7 @@ def atom_model_init(model_name: str, **kwargs):
atom_model = ocr_model_init( atom_model = ocr_model_init(
kwargs.get('ocr_show_log'), kwargs.get('ocr_show_log'),
kwargs.get('det_db_box_thresh'), kwargs.get('det_db_box_thresh'),
kwargs.get('lang') kwargs.get('lang'),
) )
elif model_name == AtomicModel.Table: elif model_name == AtomicModel.Table:
atom_model = table_model_init( atom_model = table_model_init(
......
...@@ -31,7 +31,7 @@ try: ...@@ -31,7 +31,7 @@ try:
except ImportError: except ImportError:
pass pass
from magic_pdf.model.sub_modules.model_init import ocr_model_init from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.para.para_split_v3 import para_split from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
...@@ -231,9 +231,10 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang ...@@ -231,9 +231,10 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
if len(empty_spans) > 0: if len(empty_spans) > 0:
# 初始化ocr模型 # 初始化ocr模型
atom_model_manager = AtomModelSingleton()
ocr_model = ocr_model_init( ocr_model = atom_model_manager.get_atom_model(
show_log=False, atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=lang lang=lang
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment