Commit 7eed5ee9 authored by myhloli's avatar myhloli
Browse files

refactor: streamline PDF parsing and enhance formula recognition handling

parent 1be66f98
...@@ -161,13 +161,13 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer ...@@ -161,13 +161,13 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
return page_info return page_info
def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False): def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=None, ocr_enable=False, formula_enabled=True):
middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__} middle_json = {"pdf_info": [], "_backend":"pipeline", "_version_name": __version__}
for page_index, page_model_info in enumerate(model_list): for page_index, page_model_info in enumerate(model_list):
page = pdf_doc[page_index] page = pdf_doc[page_index]
image_dict = images_list[page_index] image_dict = images_list[page_index]
page_info = page_model_info_to_page_info( page_info = page_model_info_to_page_info(
page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable page_model_info, image_dict, page, image_writer, page_index, ocr_enable=ocr_enable, formula_enabled=formula_enabled
) )
if page_info is None: if page_info is None:
page_w, page_h = map(int, page.get_size()) page_w, page_h = map(int, page.get_size())
......
...@@ -114,7 +114,7 @@ def do_parse( ...@@ -114,7 +114,7 @@ def do_parse(
pdf_doc = all_pdf_docs[idx] pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx] _lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx] _ocr_enable = ocr_enabled_list[idx]
middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable) middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, p_formula_enable)
pdf_info = middle_json["pdf_info"] pdf_info = middle_json["pdf_info"]
......
...@@ -84,7 +84,7 @@ def get_table_recog_config(): ...@@ -84,7 +84,7 @@ def get_table_recog_config():
if table_enable is not None: if table_enable is not None:
return json.loads(f'{{"enable": {table_enable}}}') return json.loads(f'{{"enable": {table_enable}}}')
else: else:
logger.warning(f"not found 'MINERU_TABLE_ENABLE' in environment variable, use 'true' as default.") # logger.warning(f"not found 'MINERU_TABLE_ENABLE' in environment variable, use 'true' as default.")
return json.loads(f'{{"enable": true}}') return json.loads(f'{{"enable": true}}')
...@@ -93,7 +93,7 @@ def get_formula_config(): ...@@ -93,7 +93,7 @@ def get_formula_config():
if formula_enable is not None: if formula_enable is not None:
return json.loads(f'{{"enable": {formula_enable}}}') return json.loads(f'{{"enable": {formula_enable}}}')
else: else:
logger.warning(f"not found 'MINERU_FORMULA_ENABLE' in environment variable, use 'true' as default.") # logger.warning(f"not found 'MINERU_FORMULA_ENABLE' in environment variable, use 'true' as default.")
return json.loads(f'{{"enable": true}}') return json.loads(f'{{"enable": true}}')
......
...@@ -9,13 +9,12 @@ import zipfile ...@@ -9,13 +9,12 @@ import zipfile
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import pymupdf
from gradio_pdf import PDF from gradio_pdf import PDF
from loguru import logger from loguru import logger
from magic_pdf.data.data_reader_writer import FileBasedDataReader from mineru.cli.common import prepare_env, do_parse
from magic_pdf.libs.hash_utils import compute_sha256 from mineru.data.data_reader_writer import FileBasedDataReader
from magic_pdf.tools.common import do_parse, prepare_env from mineru.utils.hash_utils import str_sha256
def read_fn(path): def read_fn(path):
...@@ -23,7 +22,7 @@ def read_fn(path): ...@@ -23,7 +22,7 @@ def read_fn(path):
return disk_rw.read(os.path.basename(path)) return disk_rw.read(os.path.basename(path))
def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language): def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
try: try:
...@@ -35,17 +34,14 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_en ...@@ -35,17 +34,14 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_en
parse_method = 'auto' parse_method = 'auto'
local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method) local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
do_parse( do_parse(
output_dir, output_dir=output_dir,
file_name, pdf_file_names=[file_name],
pdf_data, pdf_bytes_list=[pdf_data],
[], p_lang_list=[language],
parse_method, parse_method=parse_method,
True,
end_page_id=end_page_id, end_page_id=end_page_id,
layout_model=layout_mode, p_formula_enable=formula_enable,
formula_enable=formula_enable, p_table_enable=table_enable,
table_enable=table_enable,
lang=language,
) )
return local_md_dir, file_name return local_md_dir, file_name
except Exception as e: except Exception as e:
...@@ -96,12 +92,11 @@ def replace_image_with_base64(markdown_text, image_dir_path): ...@@ -96,12 +92,11 @@ def replace_image_with_base64(markdown_text, image_dir_path):
return re.sub(pattern, replace, markdown_text) return re.sub(pattern, replace, markdown_text)
def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table_enable, language): def to_markdown(file_path, end_pages, is_ocr, formula_enable, table_enable, language):
file_path = to_pdf(file_path) file_path = to_pdf(file_path)
# 获取识别的md文件以及压缩包文件路径 # 获取识别的md文件以及压缩包文件路径
local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr, local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language)
layout_mode, formula_enable, table_enable, language) archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
archive_zip_path = os.path.join('./output', compute_sha256(local_md_dir) + '.zip')
zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path) zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
if zip_archive_success == 0: if zip_archive_success == 0:
logger.info('压缩成功') logger.info('压缩成功')
...@@ -126,13 +121,8 @@ latex_delimiters = [ ...@@ -126,13 +121,8 @@ latex_delimiters = [
def init_model(): def init_model():
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
try: try:
model_manager = ModelSingleton() pass
txt_model = model_manager.get_model(False, False) # noqa: F841
logger.info('txt_model init final')
ocr_model = model_manager.get_model(True, False) # noqa: F841
logger.info('ocr_model init final')
return 0 return 0
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
...@@ -172,23 +162,19 @@ all_lang.extend([*other_lang, *add_lang]) ...@@ -172,23 +162,19 @@ all_lang.extend([*other_lang, *add_lang])
def to_pdf(file_path): def to_pdf(file_path):
with pymupdf.open(file_path) as f: pdf_bytes = read_fn(file_path)
if f.is_pdf: # 将pdfbytes 写入到uuid.pdf中
return file_path # 生成唯一的文件名
else: unique_filename = f'{uuid.uuid4()}.pdf'
pdf_bytes = f.convert_to_pdf()
# 将pdfbytes 写入到uuid.pdf中
# 生成唯一的文件名
unique_filename = f'{uuid.uuid4()}.pdf'
# 构建完整的文件路径 # 构建完整的文件路径
tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename) tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
# 将字节数据写入文件 # 将字节数据写入文件
with open(tmp_file_path, 'wb') as tmp_pdf_file: with open(tmp_file_path, 'wb') as tmp_pdf_file:
tmp_pdf_file.write(pdf_bytes) tmp_pdf_file.write(pdf_bytes)
return tmp_file_path return tmp_file_path
if __name__ == '__main__': if __name__ == '__main__':
...@@ -199,11 +185,12 @@ if __name__ == '__main__': ...@@ -199,11 +185,12 @@ if __name__ == '__main__':
file = gr.File(label='Please upload a PDF or image', file_types=['.pdf', '.png', '.jpeg', '.jpg']) file = gr.File(label='Please upload a PDF or image', file_types=['.pdf', '.png', '.jpeg', '.jpg'])
max_pages = gr.Slider(1, 20, 10, step=1, label='Max convert pages') max_pages = gr.Slider(1, 20, 10, step=1, label='Max convert pages')
with gr.Row(): with gr.Row():
layout_mode = gr.Dropdown(['doclayout_yolo'], label='Layout model', value='doclayout_yolo') with gr.Column():
language = gr.Dropdown(all_lang, label='Language', value='ch') is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
with gr.Column():
language = gr.Dropdown(all_lang, label='Language', value='ch')
with gr.Row(): with gr.Row():
formula_enable = gr.Checkbox(label='Enable formula recognition', value=True) formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
table_enable = gr.Checkbox(label='Enable table recognition(test)', value=True) table_enable = gr.Checkbox(label='Enable table recognition(test)', value=True)
with gr.Row(): with gr.Row():
change_bu = gr.Button('Convert') change_bu = gr.Button('Convert')
...@@ -227,7 +214,7 @@ if __name__ == '__main__': ...@@ -227,7 +214,7 @@ if __name__ == '__main__':
with gr.Tab('Markdown text'): with gr.Tab('Markdown text'):
md_text = gr.TextArea(lines=45, show_copy_button=True) md_text = gr.TextArea(lines=45, show_copy_button=True)
file.change(fn=to_pdf, inputs=file, outputs=pdf_show) file.change(fn=to_pdf, inputs=file, outputs=pdf_show)
change_bu.click(fn=to_markdown, inputs=[file, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language], change_bu.click(fn=to_markdown, inputs=[file, max_pages, is_ocr, formula_enable, table_enable, language],
outputs=[md, md_text, output_file, pdf_show]) outputs=[md, md_text, output_file, pdf_show])
clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr]) clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment