Unverified commit 92c10d1e, authored by Xiaomeng Zhao and committed by GitHub
Browse files

Merge pull request #1208 from myhloli/dev

fix(multi-threading): Enable multi-threading support for PaddleOCR.
parents 272014c4 30220233
...@@ -143,8 +143,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, ...@@ -143,8 +143,10 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
if lang == "": if lang == "":
lang = None lang = None
model_manager = ModelSingleton() # model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable) # custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
custom_model = custom_model_init(ocr, show_log, lang, layout_model, formula_enable, table_enable)
with fitz.open("pdf", pdf_bytes) as doc: with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count pdf_page_num = doc.page_count
......
...@@ -22,7 +22,7 @@ except ImportError: ...@@ -22,7 +22,7 @@ except ImportError:
from magic_pdf.config.constants import * from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
...@@ -37,6 +37,7 @@ class CustomPEKModel: ...@@ -37,6 +37,7 @@ class CustomPEKModel:
""" """
======== model init ======== ======== model init ========
""" """
self._lock = Lock()
# 获取当前文件(即 pdf_extract_kit.py)的绝对路径 # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
current_file_path = os.path.abspath(__file__) current_file_path = os.path.abspath(__file__)
# 获取当前文件所在的目录(model) # 获取当前文件所在的目录(model)
...@@ -152,9 +153,14 @@ class CustomPEKModel: ...@@ -152,9 +153,14 @@ class CustomPEKModel:
device=self.device, device=self.device,
) )
# 初始化ocr # 初始化ocr
self.ocr_model = atom_model_manager.get_atom_model( # self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR, # atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log, # ocr_show_log=show_log,
# det_db_box_thresh=0.3,
# lang=self.lang
# )
self.ocr_model = ocr_model_init(
show_log=show_log,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=self.lang lang=self.lang
) )
...@@ -211,18 +217,17 @@ class CustomPEKModel: ...@@ -211,18 +217,17 @@ class CustomPEKModel:
# ocr识别 # ocr识别
ocr_start = time.time() ocr_start = time.time()
# Process each area that requires OCR processing # Process each area that requires OCR processing
lock = Lock()
for res in ocr_res_list: for res in ocr_res_list:
new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50) new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list) adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
# OCR recognition # OCR recognition
new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR) new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
with lock: # with self._lock:
if self.apply_ocr: if self.apply_ocr:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0] ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
else: else:
ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0] ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
# Integration results # Integration results
if ocr_res: if ocr_res:
......
...@@ -31,7 +31,7 @@ try: ...@@ -31,7 +31,7 @@ try:
except ImportError: except ImportError:
pass pass
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
from magic_pdf.para.para_split_v3 import para_split from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
...@@ -231,10 +231,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang ...@@ -231,10 +231,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
if len(empty_spans) > 0: if len(empty_spans) > 0:
# 初始化ocr模型 # 初始化ocr模型
atom_model_manager = AtomModelSingleton() # atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model( # ocr_model = atom_model_manager.get_atom_model(
atom_model_name="ocr", # atom_model_name="ocr",
ocr_show_log=False, # ocr_show_log=False,
# det_db_box_thresh=0.3,
# lang=lang
# )
ocr_model = ocr_model_init(
show_log=False,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=lang lang=lang
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment