"server/vscode:/vscode.git/clone" did not exist on "49a6c8c1b28742e806dd95a36af82db5b45d181d"
Unverified Commit 1c10dc55 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #1218 from myhloli/dev

refactor(magic-pdf): optimize model initialization and concurrency control
parents ef5cffcb 012a46e0
......@@ -143,10 +143,8 @@ def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
if lang == "":
lang = None
# model_manager = ModelSingleton()
# custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
custom_model = custom_model_init(ocr, show_log, lang, layout_model, formula_enable, table_enable)
model_manager = ModelSingleton()
custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
with fitz.open("pdf", pdf_bytes) as doc:
pdf_page_num = doc.page_count
......
......@@ -22,7 +22,7 @@ except ImportError:
from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton, ocr_model_init
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
......@@ -150,14 +150,9 @@ class CustomPEKModel:
device=self.device,
)
# 初始化ocr
# self.ocr_model = atom_model_manager.get_atom_model(
# atom_model_name=AtomicModel.OCR,
# ocr_show_log=show_log,
# det_db_box_thresh=0.3,
# lang=self.lang
# )
self.ocr_model = ocr_model_init(
show_log=show_log,
self.ocr_model = atom_model_manager.get_atom_model(
atom_model_name=AtomicModel.OCR,
ocr_show_log=show_log,
det_db_box_thresh=0.3,
lang=self.lang
)
......
......@@ -57,11 +57,6 @@ def doclayout_yolo_model_init(weight, device='cpu'):
return model
import threading
current_thread = threading.current_thread()
current_thread_id = current_thread.ident
def ocr_model_init(show_log: bool = False,
det_db_box_thresh=0.3,
lang=None,
......@@ -103,7 +98,7 @@ class AtomModelSingleton:
table_model_name = kwargs.get('table_model_name', None)
if atom_model_name in [AtomicModel.OCR]:
key = (atom_model_name, lang, current_thread_id)
key = (atom_model_name, lang)
elif atom_model_name in [AtomicModel.Layout]:
key = (atom_model_name, layout_model_name)
elif atom_model_name in [AtomicModel.Table]:
......
......@@ -152,7 +152,7 @@ def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=0.33):
return False
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_model):
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
......@@ -231,13 +231,13 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, ocr_
if len(empty_spans) > 0:
# 初始化ocr模型
# atom_model_manager = AtomModelSingleton()
# ocr_model = atom_model_manager.get_atom_model(
# atom_model_name='ocr',
# ocr_show_log=False,
# det_db_box_thresh=0.3,
# lang=lang
# )
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
for span in empty_spans:
# 对span的bbox截图再ocr
......@@ -613,7 +613,7 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
def parse_page_core(
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
):
need_drop = False
drop_reason = []
......@@ -682,7 +682,7 @@ def parse_page_core(
if parse_mode == SupportedPdfParseMethod.TXT:
"""使用新版本的混合ocr方案"""
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, ocr_model)
spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
elif parse_mode == SupportedPdfParseMethod.OCR:
pass
......@@ -772,12 +772,6 @@ def pdf_parse_union(
lang=None,
):
ocr_model = ocr_model_init(
show_log=False,
det_db_box_thresh=0.3,
lang=lang
)
pdf_bytes_md5 = compute_md5(dataset.data_bits())
"""初始化空的pdf_info_dict"""
......@@ -813,7 +807,7 @@ def pdf_parse_union(
"""解析pdf中的每一页"""
if start_page_id <= page_id <= end_page_id:
page_info = parse_page_core(
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, ocr_model
page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
)
else:
page_info = page.get_page_info()
......
......@@ -14,9 +14,7 @@ from gradio_pdf import PDF
from loguru import logger
from magic_pdf.data.data_reader_writer import FileBasedDataReader
from magic_pdf.libs.config_reader import get_device
from magic_pdf.libs.hash_utils import compute_sha256
from magic_pdf.model.sub_modules.model_utils import get_vram
from magic_pdf.tools.common import do_parse, prepare_env
......@@ -185,16 +183,6 @@ def to_pdf(file_path):
return tmp_file_path
def get_concurrency_limit(vram_threshold=7.5):
vram = get_vram(device = get_device())
if vram is not None and isinstance(vram, (int, float)):
concurrency_limit = max(1, int(vram // vram_threshold))
else:
concurrency_limit = 1
# logger.info(f'concurrency_limit: {concurrency_limit}')
return concurrency_limit
if __name__ == '__main__':
with gr.Blocks() as demo:
gr.HTML(header)
......@@ -231,7 +219,7 @@ if __name__ == '__main__':
md_text = gr.TextArea(lines=45, show_copy_button=True)
file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
outputs=[md, md_text, output_file, pdf_show], concurrency_limit=get_concurrency_limit())
outputs=[md, md_text, output_file, pdf_show])
clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
demo.launch(server_name='0.0.0.0')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment