Commit 47a83d28 authored by myhloli's avatar myhloli
Browse files

refactor(ocr): replace AtomModelSingleton with ocr_model_init for OCR model instantiation

- Remove usage of AtomModelSingleton for OCR model creation
- Add ocr_model_init function to initialize OCR model
- Update OCR model initialization in pdf_extract_kit.py and pdf_parse_union_core_v2.py
- Modify txt_spans_extract_v2 function to accept ocr_model as a parameter
- Update parse_page_core function to use ocr_model instead of lang for OCR processing
parent f2a92d57
...@@ -22,7 +22,7 @@ except ImportError: ...@@ -22,7 +22,7 @@ except ImportError:
from magic_pdf.config.constants import * from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
...@@ -150,9 +150,14 @@ class CustomPEKModel: ...@@ -150,9 +150,14 @@ class CustomPEKModel:
device=self.device, device=self.device,
) )
# 初始化ocr # 初始化ocr
self.ocr_model = atom_model_manager.get_atom_model( # self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR, # atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log, # ocr_show_log=show_log,
# det_db_box_thresh=0.3,
# lang=self.lang
# )
self.ocr_model = ocr_model_init(
show_log=show_log,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=self.lang lang=self.lang
) )
......
...@@ -31,7 +31,7 @@ try: ...@@ -31,7 +31,7 @@ try:
except ImportError: except ImportError:
pass pass
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
from magic_pdf.para.para_split_v3 import para_split from magic_pdf.para.para_split_v3 import para_split
from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2
from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table from magic_pdf.pre_proc.cut_image import ocr_cut_image_and_table
...@@ -152,7 +152,7 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33): ...@@ -152,7 +152,7 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
return False return False
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang): def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_model):
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks'] text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
...@@ -231,13 +231,13 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang ...@@ -231,13 +231,13 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
if len(empty_spans) > 0: if len(empty_spans) > 0:
# 初始化ocr模型 # 初始化ocr模型
atom_model_manager = AtomModelSingleton() # atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model( # ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr', # atom_model_name='ocr',
ocr_show_log=False, # ocr_show_log=False,
det_db_box_thresh=0.3, # det_db_box_thresh=0.3,
lang=lang # lang=lang
) # )
for span in empty_spans: for span in empty_spans:
# 对span的bbox截图再ocr # 对span的bbox截图再ocr
...@@ -613,7 +613,7 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks): ...@@ -613,7 +613,7 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
def parse_page_core( def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
): ):
need_drop = False need_drop = False
drop_reason = [] drop_reason = []
...@@ -682,7 +682,7 @@ def parse_page_core( ...@@ -682,7 +682,7 @@ def parse_page_core(
if parse_mode == SupportedPdfParseMethod.TXT: if parse_mode == SupportedPdfParseMethod.TXT:
"""使用新版本的混合ocr方案""" """使用新版本的混合ocr方案"""
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang) spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, ocr_model)
elif parse_mode == SupportedPdfParseMethod.OCR: elif parse_mode == SupportedPdfParseMethod.OCR:
pass pass
...@@ -771,6 +771,13 @@ def pdf_parse_union( ...@@ -771,6 +771,13 @@ def pdf_parse_union(
debug_mode=False, debug_mode=False,
lang=None, lang=None,
): ):
ocr_model = ocr_model_init(
show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
pdf_bytes_md5 = compute_md5(dataset.data_bits()) pdf_bytes_md5 = compute_md5(dataset.data_bits())
"""初始化空的pdf_info_dict""" """初始化空的pdf_info_dict"""
...@@ -806,7 +813,7 @@ def pdf_parse_union( ...@@ -806,7 +813,7 @@ def pdf_parse_union(
"""解析pdf中的每一页""" """解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id: if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core( page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
) )
else: else:
page_info = page.get_page_info() page_info = page.get_page_info()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment