"dev/vscode:/vscode.git/clone" did not exist on "c732df654546c5a17797464335524390a1e865e5"
Commit 012a46e0 authored by myhloli
Browse files

refactor(magic-pdf): optimize model initialization and concurrency control

- Remove concurrency limit logic from app.py
- Update model initialization process in various modules
- Remove unused VRAM check for concurrency limit
- Refactor OCR model initialization in pdf_extract_kit.py
- Update txt_spans_extract_v2 function to use lang parameter instead of ocr_model
parent 47a83d28
...@@ -143,10 +143,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False, ...@@ -143,10 +143,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
if lang == "": if lang == "":
lang = None lang = None
# model_manager = ModelSingleton() model_manager = ModelSingleton()
# custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable) custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
custom_model = custom_model_init(ocr, show_log, lang, layout_model, formula_enable, table_enable)
with fitz.open("pdf", pdf_bytes) as doc: with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count pdf_page_num = doc.page_count
......
...@@ -22,7 +22,7 @@ except ImportError: ...@@ -22,7 +22,7 @@ except ImportError:
from magic_pdf.config.constants import * from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import ( from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res) clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import ( from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
...@@ -150,14 +150,9 @@ class CustomPEKModel: ...@@ -150,14 +150,9 @@ class CustomPEKModel:
device=self.device, device=self.device,
) )
# 初始化ocr # 初始化ocr
# self.ocr_model = atom_model_manager.get_atom_model( self.ocr_model = atom_model_manager.get_atom_model(
# atom_model_name=AtomicModel.OCR, atom_model_name=AtomicModel.OCR,
# ocr_show_log=show_log, ocr_show_log=show_log,
# det_db_box_thresh=0.3,
# lang=self.lang
# )
self.ocr_model = ocr_model_init(
show_log=show_log,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=self.lang lang=self.lang
) )
......
...@@ -57,11 +57,6 @@ def doclayout_yolo_model_init(weight, device='cpu'): ...@@ -57,11 +57,6 @@ def doclayout_yolo_model_init(weight, device='cpu'):
return model return model
import threading
current_thread = threading.current_thread()
current_thread_id = current_thread.ident
def ocr_model_init(show_log: bool = False, def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3, det_db_box_thresh=0.3,
lang=None, lang=None,
...@@ -103,7 +98,7 @@ class AtomModelSingleton: ...@@ -103,7 +98,7 @@ class AtomModelSingleton:
table_model_name = kwargs.get('table_model_name', None) table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]: if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang, current_thread_id) key = (atom_model_name, lang)
elif atom_model_name in [AtomicModel.Layout]: elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name) key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]: elif atom_model_name in [AtomicModel.Table]:
......
...@@ -152,7 +152,7 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33): ...@@ -152,7 +152,7 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
return False return False
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_model): def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks'] text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
...@@ -231,13 +231,13 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_ ...@@ -231,13 +231,13 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_
if len(empty_spans) > 0: if len(empty_spans) > 0:
# 初始化ocr模型 # 初始化ocr模型
# atom_model_manager = AtomModelSingleton() atom_model_manager = AtomModelSingleton()
# ocr_model = atom_model_manager.get_atom_model( ocr_model = atom_model_manager.get_atom_model(
# atom_model_name='ocr', atom_model_name='ocr',
# ocr_show_log=False, ocr_show_log=False,
# det_db_box_thresh=0.3, det_db_box_thresh=0.3,
# lang=lang lang=lang
# ) )
for span in empty_spans: for span in empty_spans:
# 对span的bbox截图再ocr # 对span的bbox截图再ocr
...@@ -613,7 +613,7 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks): ...@@ -613,7 +613,7 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
def parse_page_core( def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
): ):
need_drop = False need_drop = False
drop_reason = [] drop_reason = []
...@@ -682,7 +682,7 @@ def parse_page_core( ...@@ -682,7 +682,7 @@ def parse_page_core(
if parse_mode == SupportedPdfParseMethod.TXT: if parse_mode == SupportedPdfParseMethod.TXT:
"""使用新版本的混合ocr方案""" """使用新版本的混合ocr方案"""
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, ocr_model) spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
elif parse_mode == SupportedPdfParseMethod.OCR: elif parse_mode == SupportedPdfParseMethod.OCR:
pass pass
...@@ -772,12 +772,6 @@ def pdf_parse_union( ...@@ -772,12 +772,6 @@ def pdf_parse_union(
lang=None, lang=None,
): ):
ocr_model = ocr_model_init(
show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
pdf_bytes_md5 = compute_md5(dataset.data_bits()) pdf_bytes_md5 = compute_md5(dataset.data_bits())
"""初始化空的pdf_info_dict""" """初始化空的pdf_info_dict"""
...@@ -813,7 +807,7 @@ def pdf_parse_union( ...@@ -813,7 +807,7 @@ def pdf_parse_union(
"""解析pdf中的每一页""" """解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id: if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core( page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
) )
else: else:
page_info = page.get_page_info() page_info = page.get_page_info()
......
...@@ -14,9 +14,7 @@ from gradio_pdf import PDF ...@@ -14,9 +14,7 @@ from gradio_pdf import PDF
from loguru import logger from loguru import logger
from magic_pdf.data.data_reader_writer import FileBasedDataReader from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.libs.config_reader import get_device
from magic_pdf.libs.hash_utils import compute_sha256 from magic_pdf.libs.hash_utils import compute_sha256
from magic_pdf.model.sub_modules.model_utils import get_vram
from magic_pdf.tools.common import do_parse, prepare_env from magic_pdf.tools.common import do_parse, prepare_env
...@@ -185,16 +183,6 @@ def to_pdf(file_path): ...@@ -185,16 +183,6 @@ def to_pdf(file_path):
return tmp_file_path return tmp_file_path
def get_concurrency_limit(vram_threshold=7.5):
vram = get_vram(device = get_device())
if vram is not None and isinstance(vram, (int, float)):
concurrency_limit = max(1, int(vram // vram_threshold))
else:
concurrency_limit = 1
# logger.info(f'concurrency_limit: {concurrency_limit}')
return concurrency_limit
if __name__ == '__main__': if __name__ == '__main__':
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML(header) gr.HTML(header)
...@@ -231,7 +219,7 @@ if __name__ == '__main__': ...@@ -231,7 +219,7 @@ if __name__ == '__main__':
md_text = gr.TextArea(lines=45, show_copy_button=True) md_text = gr.TextArea(lines=45, show_copy_button=True)
file.upload(fn=to_pdf, inputs=file, outputs=pdf_show) file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language], change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
outputs=[md, md_text, output_file, pdf_show], concurrency_limit=get_concurrency_limit()) outputs=[md, md_text, output_file, pdf_show])
clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language]) clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
demo.launch(server_name='0.0.0.0') demo.launch(server_name='0.0.0.0')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment