Unverified Commit 97000c0b authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2854 from myhloli/dev

Dev
parents a76e3b60 5d4263d4
from mineru.utils.config_reader import get_latex_delimiter_config import os
from mineru.utils.config_reader import get_latex_delimiter_config, get_formula_enable, get_table_enable
from mineru.utils.enum_class import MakeMode, BlockType, ContentType from mineru.utils.enum_class import MakeMode, BlockType, ContentType
...@@ -16,7 +18,7 @@ display_right_delimiter = delimiters['display']['right'] ...@@ -16,7 +18,7 @@ display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left'] inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right'] inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block): def merge_para_with_text(para_block, formula_enable=True, img_buket_path=''):
para_text = '' para_text = ''
for line in para_block['lines']: for line in para_block['lines']:
for j, span in enumerate(line['spans']): for j, span in enumerate(line['spans']):
...@@ -27,7 +29,11 @@ def merge_para_with_text(para_block): ...@@ -27,7 +29,11 @@ def merge_para_with_text(para_block):
elif span_type == ContentType.INLINE_EQUATION: elif span_type == ContentType.INLINE_EQUATION:
content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}" content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
elif span_type == ContentType.INTERLINE_EQUATION: elif span_type == ContentType.INTERLINE_EQUATION:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n" if formula_enable:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
else:
if span.get('image_path', ''):
content = f"![]({img_buket_path}/{span['image_path']})"
# content = content.strip() # content = content.strip()
if content: if content:
if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]: if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
...@@ -39,13 +45,13 @@ def merge_para_with_text(para_block): ...@@ -39,13 +45,13 @@ def merge_para_with_text(para_block):
para_text += content para_text += content
return para_text return para_text
def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''): def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path=''):
page_markdown = [] page_markdown = []
for para_block in para_blocks: for para_block in para_blocks:
para_text = '' para_text = ''
para_type = para_block['type'] para_type = para_block['type']
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]: if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
para_text = merge_para_with_text(para_block) para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
elif para_type == BlockType.TITLE: elif para_type == BlockType.TITLE:
title_level = get_title_level(para_block) title_level = get_title_level(para_block)
para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}' para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
...@@ -95,10 +101,14 @@ def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''): ...@@ -95,10 +101,14 @@ def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
for span in line['spans']: for span in line['spans']:
if span['type'] == ContentType.TABLE: if span['type'] == ContentType.TABLE:
# if processed by table model # if processed by table model
if span.get('html', ''): if table_enable:
para_text += f"\n{span['html']}\n" if span.get('html', ''):
elif span.get('image_path', ''): para_text += f"\n{span['html']}\n"
para_text += f"![]({img_buket_path}/{span['image_path']})" elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
else:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd.拼table_footnote for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TABLE_FOOTNOTE: if block['type'] == BlockType.TABLE_FOOTNOTE:
para_text += '\n' + merge_para_with_text(block) + ' ' para_text += '\n' + merge_para_with_text(block) + ' '
...@@ -177,6 +187,10 @@ def union_make(pdf_info_dict: list, ...@@ -177,6 +187,10 @@ def union_make(pdf_info_dict: list,
make_mode: str, make_mode: str,
img_buket_path: str = '', img_buket_path: str = '',
): ):
formula_enable = get_formula_enable(os.getenv('MINERU_VLM_FORMULA_ENABLE', 'True').lower() == 'true')
table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
output_content = [] output_content = []
for page_info in pdf_info_dict: for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks') paras_of_layout = page_info.get('para_blocks')
...@@ -184,7 +198,7 @@ def union_make(pdf_info_dict: list, ...@@ -184,7 +198,7 @@ def union_make(pdf_info_dict: list,
if not paras_of_layout: if not paras_of_layout:
continue continue
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]: if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path) page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, formula_enable, table_enable, img_buket_path)
output_content.extend(page_markdown) output_content.extend(page_markdown)
elif make_mode == MakeMode.CONTENT_LIST: elif make_mode == MakeMode.CONTENT_LIST:
for para_block in paras_of_layout: for para_block in paras_of_layout:
......
...@@ -180,8 +180,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i ...@@ -180,8 +180,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
p_lang_list=lang_list, p_lang_list=lang_list,
backend=backend, backend=backend,
parse_method=method, parse_method=method,
p_formula_enable=formula_enable, formula_enable=formula_enable,
p_table_enable=table_enable, table_enable=table_enable,
server_url=server_url, server_url=server_url,
start_page_id=start_page_id, start_page_id=start_page_id,
end_page_id=end_page_id end_page_id=end_page_id
......
...@@ -298,8 +298,8 @@ def do_parse( ...@@ -298,8 +298,8 @@ def do_parse(
p_lang_list: list[str], p_lang_list: list[str],
backend="pipeline", backend="pipeline",
parse_method="auto", parse_method="auto",
p_formula_enable=True, formula_enable=True,
p_table_enable=True, table_enable=True,
server_url=None, server_url=None,
f_draw_layout_bbox=True, f_draw_layout_bbox=True,
f_draw_span_bbox=True, f_draw_span_bbox=True,
...@@ -318,7 +318,7 @@ def do_parse( ...@@ -318,7 +318,7 @@ def do_parse(
if backend == "pipeline": if backend == "pipeline":
_process_pipeline( _process_pipeline(
output_dir, pdf_file_names, pdf_bytes_list, p_lang_list, output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
parse_method, p_formula_enable, p_table_enable, parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json, f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
) )
...@@ -326,6 +326,9 @@ def do_parse( ...@@ -326,6 +326,9 @@ def do_parse(
if backend.startswith("vlm-"): if backend.startswith("vlm-"):
backend = backend[4:] backend = backend[4:]
os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
_process_vlm( _process_vlm(
output_dir, pdf_file_names, pdf_bytes_list, backend, output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json, f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
...@@ -341,8 +344,8 @@ async def aio_do_parse( ...@@ -341,8 +344,8 @@ async def aio_do_parse(
p_lang_list: list[str], p_lang_list: list[str],
backend="pipeline", backend="pipeline",
parse_method="auto", parse_method="auto",
p_formula_enable=True, formula_enable=True,
p_table_enable=True, table_enable=True,
server_url=None, server_url=None,
f_draw_layout_bbox=True, f_draw_layout_bbox=True,
f_draw_span_bbox=True, f_draw_span_bbox=True,
...@@ -362,7 +365,7 @@ async def aio_do_parse( ...@@ -362,7 +365,7 @@ async def aio_do_parse(
# pipeline模式暂不支持异步,使用同步处理方式 # pipeline模式暂不支持异步,使用同步处理方式
_process_pipeline( _process_pipeline(
output_dir, pdf_file_names, pdf_bytes_list, p_lang_list, output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
parse_method, p_formula_enable, p_table_enable, parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json, f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
) )
...@@ -370,6 +373,9 @@ async def aio_do_parse( ...@@ -370,6 +373,9 @@ async def aio_do_parse(
if backend.startswith("vlm-"): if backend.startswith("vlm-"):
backend = backend[4:] backend = backend[4:]
os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
await _async_process_vlm( await _async_process_vlm(
output_dir, pdf_file_names, pdf_bytes_list, backend, output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json, f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
......
...@@ -93,8 +93,8 @@ async def parse_pdf( ...@@ -93,8 +93,8 @@ async def parse_pdf(
p_lang_list=actual_lang_list, p_lang_list=actual_lang_list,
backend=backend, backend=backend,
parse_method=parse_method, parse_method=parse_method,
p_formula_enable=formula_enable, formula_enable=formula_enable,
p_table_enable=table_enable, table_enable=table_enable,
server_url=server_url, server_url=server_url,
f_draw_layout_bbox=False, f_draw_layout_bbox=False,
f_draw_span_bbox=False, f_draw_span_bbox=False,
......
...@@ -38,8 +38,8 @@ async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, t ...@@ -38,8 +38,8 @@ async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, t
p_lang_list=[language], p_lang_list=[language],
parse_method=parse_method, parse_method=parse_method,
end_page_id=end_page_id, end_page_id=end_page_id,
p_formula_enable=formula_enable, formula_enable=formula_enable,
p_table_enable=table_enable, table_enable=table_enable,
backend=backend, backend=backend,
server_url=url, server_url=url,
) )
...@@ -179,11 +179,11 @@ def to_pdf(file_path): ...@@ -179,11 +179,11 @@ def to_pdf(file_path):
# 更新界面函数 # 更新界面函数
def update_interface(backend_choice): def update_interface(backend_choice):
if backend_choice in ["vlm-transformers", "vlm-sglang-engine"]: if backend_choice in ["vlm-transformers", "vlm-sglang-engine"]:
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) return gr.update(visible=False), gr.update(visible=False)
elif backend_choice in ["vlm-sglang-client"]: # pipeline elif backend_choice in ["vlm-sglang-client"]:
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) return gr.update(visible=True), gr.update(visible=False)
elif backend_choice in ["pipeline"]: elif backend_choice in ["pipeline"]:
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True) return gr.update(visible=False), gr.update(visible=True)
else: else:
pass pass
...@@ -230,7 +230,7 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil ...@@ -230,7 +230,7 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
try: try:
print("Start init SgLang engine...") print("Start init SgLang engine...")
from mineru.backend.vlm.vlm_analyze import ModelSingleton from mineru.backend.vlm.vlm_analyze import ModelSingleton
modelsingleton = ModelSingleton() model_singleton = ModelSingleton()
model_params = { model_params = {
"enable_torch_compile": torch_compile_enable "enable_torch_compile": torch_compile_enable
...@@ -239,7 +239,7 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil ...@@ -239,7 +239,7 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
if mem_fraction_static is not None: if mem_fraction_static is not None:
model_params["mem_fraction_static"] = mem_fraction_static model_params["mem_fraction_static"] = mem_fraction_static
predictor = modelsingleton.get_model( predictor = model_singleton.get_model(
"sglang-engine", "sglang-engine",
None, None,
None, None,
...@@ -266,14 +266,16 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil ...@@ -266,14 +266,16 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"] drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
preferred_option = "pipeline" preferred_option = "pipeline"
backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option) backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
with gr.Row(visible=False) as ocr_options:
language = gr.Dropdown(all_lang, label='Language', value='ch')
with gr.Row(visible=False) as client_options: with gr.Row(visible=False) as client_options:
url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000') url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
with gr.Row(visible=False) as pipeline_options: with gr.Row(equal_height=True):
is_ocr = gr.Checkbox(label='Force enable OCR', value=False) with gr.Column():
formula_enable = gr.Checkbox(label='Enable formula recognition', value=True) gr.Markdown("**Recognition Options:**")
table_enable = gr.Checkbox(label='Enable table recognition(test)', value=True) formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
table_enable = gr.Checkbox(label='Enable table recognition', value=True)
with gr.Column(visible=False) as ocr_options:
language = gr.Dropdown(all_lang, label='Language', value='ch')
is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
with gr.Row(): with gr.Row():
change_bu = gr.Button('Convert') change_bu = gr.Button('Convert')
clear_bu = gr.ClearButton(value='Clear') clear_bu = gr.ClearButton(value='Clear')
...@@ -302,14 +304,14 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil ...@@ -302,14 +304,14 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
backend.change( backend.change(
fn=update_interface, fn=update_interface,
inputs=[backend], inputs=[backend],
outputs=[client_options, ocr_options, pipeline_options], outputs=[client_options, ocr_options],
api_name=False api_name=False
) )
# 添加demo.load事件,在页面加载时触发一次界面更新 # 添加demo.load事件,在页面加载时触发一次界面更新
demo.load( demo.load(
fn=update_interface, fn=update_interface,
inputs=[backend], inputs=[backend],
outputs=[client_options, ocr_options, pipeline_options], outputs=[client_options, ocr_options],
api_name=False api_name=False
) )
clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr]) clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment