Unverified commit 66e616bd, authored by Xiaomeng Zhao, committed by GitHub

Merge pull request #2895 from opendatalab/release-2.1.0

Release 2.1.0
parents 592b659e a4c9a07b
-from mineru.utils.config_reader import get_latex_delimiter_config
+import os
+
+from mineru.utils.config_reader import get_latex_delimiter_config, get_formula_enable, get_table_enable
 from mineru.utils.enum_class import MakeMode, BlockType, ContentType
@@ -16,7 +18,7 @@
 display_right_delimiter = delimiters['display']['right']
 inline_left_delimiter = delimiters['inline']['left']
 inline_right_delimiter = delimiters['inline']['right']

-def merge_para_with_text(para_block):
+def merge_para_with_text(para_block, formula_enable=True, img_buket_path=''):
     para_text = ''
     for line in para_block['lines']:
         for j, span in enumerate(line['spans']):
@@ -27,7 +29,11 @@ def merge_para_with_text(para_block):
             elif span_type == ContentType.INLINE_EQUATION:
                 content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
             elif span_type == ContentType.INTERLINE_EQUATION:
-                content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
+                if formula_enable:
+                    content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
+                else:
+                    if span.get('image_path', ''):
+                        content = f"![]({img_buket_path}/{span['image_path']})"
             # content = content.strip()
             if content:
                 if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
@@ -39,13 +45,13 @@ def merge_para_with_text(para_block):
                     para_text += content
     return para_text

-def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
+def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path=''):
     page_markdown = []
     for para_block in para_blocks:
         para_text = ''
         para_type = para_block['type']
         if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
-            para_text = merge_para_with_text(para_block)
+            para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
         elif para_type == BlockType.TITLE:
             title_level = get_title_level(para_block)
             para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
@@ -95,10 +101,14 @@ def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
                 for span in line['spans']:
                     if span['type'] == ContentType.TABLE:
                         # if processed by table model
-                        if span.get('html', ''):
-                            para_text += f"\n{span['html']}\n"
-                        elif span.get('image_path', ''):
-                            para_text += f"![]({img_buket_path}/{span['image_path']})"
+                        if table_enable:
+                            if span.get('html', ''):
+                                para_text += f"\n{span['html']}\n"
+                            elif span.get('image_path', ''):
+                                para_text += f"![]({img_buket_path}/{span['image_path']})"
+                        else:
+                            if span.get('image_path', ''):
+                                para_text += f"![]({img_buket_path}/{span['image_path']})"
             for block in para_block['blocks']:  # 3rd: append table_footnote
                 if block['type'] == BlockType.TABLE_FOOTNOTE:
                     para_text += '\n' + merge_para_with_text(block) + ' '
@@ -120,25 +130,25 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
     para_content = {}
     if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
         para_content = {
-            'type': 'text',
+            'type': ContentType.TEXT,
             'text': merge_para_with_text(para_block),
         }
     elif para_type == BlockType.TITLE:
         title_level = get_title_level(para_block)
         para_content = {
-            'type': 'text',
+            'type': ContentType.TEXT,
             'text': merge_para_with_text(para_block),
         }
         if title_level != 0:
             para_content['text_level'] = title_level
     elif para_type == BlockType.INTERLINE_EQUATION:
         para_content = {
-            'type': 'equation',
+            'type': ContentType.EQUATION,
             'text': merge_para_with_text(para_block),
             'text_format': 'latex',
         }
     elif para_type == BlockType.IMAGE:
-        para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
+        para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.IMAGE_BODY:
                 for line in block['lines']:
@@ -147,11 +157,11 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
                         if span.get('image_path', ''):
                             para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
             if block['type'] == BlockType.IMAGE_CAPTION:
-                para_content['img_caption'].append(merge_para_with_text(block))
+                para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(block))
             if block['type'] == BlockType.IMAGE_FOOTNOTE:
-                para_content['img_footnote'].append(merge_para_with_text(block))
+                para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(block))
     elif para_type == BlockType.TABLE:
-        para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
+        para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.TABLE_BODY:
                 for line in block['lines']:
@@ -159,15 +169,15 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
                     if span['type'] == ContentType.TABLE:
                         if span.get('html', ''):
-                            para_content['table_body'] = f"{span['html']}"
+                            para_content[BlockType.TABLE_BODY] = f"{span['html']}"
                         if span.get('image_path', ''):
                             para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
             if block['type'] == BlockType.TABLE_CAPTION:
-                para_content['table_caption'].append(merge_para_with_text(block))
+                para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
             if block['type'] == BlockType.TABLE_FOOTNOTE:
-                para_content['table_footnote'].append(merge_para_with_text(block))
+                para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))

     para_content['page_idx'] = page_idx
@@ -177,6 +187,10 @@ def union_make(pdf_info_dict: list,
                make_mode: str,
                img_buket_path: str = '',
                ):
+    formula_enable = get_formula_enable(os.getenv('MINERU_VLM_FORMULA_ENABLE', 'True').lower() == 'true')
+    table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
+
     output_content = []
     for page_info in pdf_info_dict:
         paras_of_layout = page_info.get('para_blocks')
@@ -184,7 +198,7 @@ def union_make(pdf_info_dict: list,
         if not paras_of_layout:
             continue
         if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
-            page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
+            page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, formula_enable, table_enable, img_buket_path)
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.CONTENT_LIST:
            for para_block in paras_of_layout:
......
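Note: the VLM union_make path above now takes its formula/table switches from process environment variables rather than from extra function arguments. Below is a minimal, self-contained sketch of that round-trip pattern for orientation only; the `_env_flag` helper is illustrative and is not part of MinerU (the real code additionally wraps the value in `get_formula_enable` / `get_table_enable`).

import os

def _env_flag(name: str, default: bool = True) -> bool:
    # str(True) == "True", so compare case-insensitively against "true"
    return os.getenv(name, str(default)).lower() == 'true'

# Upstream, do_parse stores the CLI switches as strings:
os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(False)
os.environ['MINERU_VLM_TABLE_ENABLE'] = str(True)

# Downstream, union_make reads them back as booleans:
print(_env_flag('MINERU_VLM_FORMULA_ENABLE'))  # False
print(_env_flag('MINERU_VLM_TABLE_ENABLE'))    # True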
@@ -4,12 +4,14 @@ import click
 from pathlib import Path
 from loguru import logger

+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.config_reader import get_device
 from mineru.utils.model_utils import get_vram
 from ..version import __version__
 from .common import do_parse, read_fn, pdf_suffixes, image_suffixes

-@click.command()
+@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
+@click.pass_context
 @click.version_option(__version__,
                       '--version',
                       '-v',
@@ -60,7 +62,8 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
     '-l',
     '--lang',
     'lang',
-    type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']),
+    type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka',
+                       'latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']),
     help="""
     Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
     Without languages specified, 'ch' will be used by default.
@@ -136,7 +139,14 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
 )
-def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
+def main(
+        ctx,
+        input_path, output_dir, method, backend, lang, server_url,
+        start_page_id, end_page_id, formula_enable, table_enable,
+        device_mode, virtual_vram, model_source, **kwargs
+):
+    kwargs.update(arg_parse(ctx))
     if not backend.endswith('-client'):
         def get_device_mode() -> str:
@@ -179,11 +189,12 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
                 p_lang_list=lang_list,
                 backend=backend,
                 parse_method=method,
-                p_formula_enable=formula_enable,
-                p_table_enable=table_enable,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
                 server_url=server_url,
                 start_page_id=start_page_id,
-                end_page_id=end_page_id
+                end_page_id=end_page_id,
+                **kwargs,
             )
     except Exception as e:
         logger.exception(e)
......
@@ -14,9 +14,10 @@ from mineru.utils.enum_class import MakeMode
 from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
 from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
 from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
+from mineru.backend.vlm.vlm_analyze import aio_doc_analyze as aio_vlm_doc_analyze

 pdf_suffixes = [".pdf"]
-image_suffixes = [".png", ".jpeg", ".jpg"]
+image_suffixes = [".png", ".jpeg", ".jpg", ".webp", ".gif"]

 def read_fn(path):
@@ -73,155 +74,318 @@ def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page_id=None):
     return output_bytes
-def do_parse(
-    output_dir,
-    pdf_file_names: list[str],
-    pdf_bytes_list: list[bytes],
-    p_lang_list: list[str],
-    backend="pipeline",
-    parse_method="auto",
-    p_formula_enable=True,
-    p_table_enable=True,
-    server_url=None,
-    f_draw_layout_bbox=True,
-    f_draw_span_bbox=True,
-    f_dump_md=True,
-    f_dump_middle_json=True,
-    f_dump_model_output=True,
-    f_dump_orig_pdf=True,
-    f_dump_content_list=True,
-    f_make_md_mode=MakeMode.MM_MD,
-    start_page_id=0,
-    end_page_id=None,
-):
+def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
+    """Prepare the PDF byte data for processing."""
+    result = []
+    for pdf_bytes in pdf_bytes_list:
+        new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
+        result.append(new_pdf_bytes)
+    return result
+
+
+def _process_output(
+        pdf_info,
+        pdf_bytes,
+        pdf_file_name,
+        local_md_dir,
+        local_image_dir,
+        md_writer,
+        f_draw_layout_bbox,
+        f_draw_span_bbox,
+        f_dump_orig_pdf,
+        f_dump_md,
+        f_dump_content_list,
+        f_dump_middle_json,
+        f_dump_model_output,
+        f_make_md_mode,
+        middle_json,
+        model_output=None,
+        is_pipeline=True
+):
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
"""处理输出文件"""
if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
if f_dump_orig_pdf:
md_writer.write(
f"{pdf_file_name}_origin.pdf",
pdf_bytes,
)
image_dir = str(os.path.basename(local_image_dir))
if f_dump_md:
make_func = pipeline_union_make if is_pipeline else vlm_union_make
md_content_str = make_func(pdf_info, f_make_md_mode, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
if f_dump_content_list:
make_func = pipeline_union_make if is_pipeline else vlm_union_make
content_list = make_func(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
if f_dump_middle_json:
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
if f_dump_model_output:
if is_pipeline:
md_writer.write_string(
f"{pdf_file_name}_model.json",
json.dumps(model_output, ensure_ascii=False, indent=4),
)
else:
output_text = ("\n" + "-" * 50 + "\n").join(model_output)
md_writer.write_string(
f"{pdf_file_name}_model_output.txt",
output_text,
)
logger.info(f"local output dir is {local_md_dir}")
def _process_pipeline(
output_dir,
pdf_file_names,
pdf_bytes_list,
p_lang_list,
parse_method,
p_formula_enable,
p_table_enable,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
):
"""处理pipeline后端逻辑"""
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
pipeline_doc_analyze(
pdf_bytes_list, p_lang_list, parse_method=parse_method,
formula_enable=p_formula_enable, table_enable=p_table_enable
)
)
for idx, model_list in enumerate(infer_results):
model_json = copy.deepcopy(model_list)
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
images_list = all_image_lists[idx]
pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx]
middle_json = pipeline_result_to_middle_json(
model_list, images_list, pdf_doc, image_writer,
_lang, _ocr_enable, p_formula_enable
)
pdf_info = middle_json["pdf_info"]
pdf_bytes = pdf_bytes_list[idx]
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, model_json, is_pipeline=True
)
async def _async_process_vlm(
output_dir,
pdf_file_names,
pdf_bytes_list,
backend,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""异步处理VLM后端逻辑"""
parse_method = "vlm"
f_draw_span_bbox = False
if not backend.endswith("client"):
server_url = None
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = await aio_vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
)
pdf_info = middle_json["pdf_info"]
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, infer_result, is_pipeline=False
)
def _process_vlm(
output_dir,
pdf_file_names,
pdf_bytes_list,
backend,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""同步处理VLM后端逻辑"""
parse_method = "vlm"
f_draw_span_bbox = False
if not backend.endswith("client"):
server_url = None
if backend == "pipeline": for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
for idx, pdf_bytes in enumerate(pdf_bytes_list):
new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
pdf_bytes_list[idx] = new_pdf_bytes
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = pipeline_doc_analyze(pdf_bytes_list, p_lang_list, parse_method=parse_method, formula_enable=p_formula_enable,table_enable=p_table_enable)
for idx, model_list in enumerate(infer_results):
model_json = copy.deepcopy(model_list)
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
images_list = all_image_lists[idx]
pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx]
middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, p_formula_enable)
pdf_info = middle_json["pdf_info"]
pdf_bytes = pdf_bytes_list[idx]
if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
if f_dump_orig_pdf: middle_json, infer_result = vlm_doc_analyze(
md_writer.write( pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
f"{pdf_file_name}_origin.pdf", )
pdf_bytes,
)
if f_dump_md: pdf_info = middle_json["pdf_info"]
image_dir = str(os.path.basename(local_image_dir))
md_content_str = pipeline_union_make(pdf_info, f_make_md_mode, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
if f_dump_content_list: _process_output(
image_dir = str(os.path.basename(local_image_dir)) pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
md_writer.write_string( f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f"{pdf_file_name}_content_list.json", f_make_md_mode, middle_json, infer_result, is_pipeline=False
json.dumps(content_list, ensure_ascii=False, indent=4), )
)
if f_dump_middle_json:
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
if f_dump_model_output: def do_parse(
md_writer.write_string( output_dir,
f"{pdf_file_name}_model.json", pdf_file_names: list[str],
json.dumps(model_json, ensure_ascii=False, indent=4), pdf_bytes_list: list[bytes],
) p_lang_list: list[str],
backend="pipeline",
parse_method="auto",
formula_enable=True,
table_enable=True,
server_url=None,
f_draw_layout_bbox=True,
f_draw_span_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_output=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
**kwargs,
):
# 预处理PDF字节数据
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
logger.info(f"local output dir is {local_md_dir}") if backend == "pipeline":
_process_pipeline(
output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
)
else: else:
if backend.startswith("vlm-"): if backend.startswith("vlm-"):
backend = backend[4:] backend = backend[4:]
f_draw_span_bbox = False os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
parse_method = "vlm" os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx] _process_vlm(
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id) output_dir, pdf_file_names, pdf_bytes_list, backend,
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method) f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir) f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url) server_url, **kwargs,
)
pdf_info = middle_json["pdf_info"]
if f_draw_layout_bbox: async def aio_do_parse(
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf") output_dir,
pdf_file_names: list[str],
if f_draw_span_bbox: pdf_bytes_list: list[bytes],
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf") p_lang_list: list[str],
backend="pipeline",
if f_dump_orig_pdf: parse_method="auto",
md_writer.write( formula_enable=True,
f"{pdf_file_name}_origin.pdf", table_enable=True,
pdf_bytes, server_url=None,
) f_draw_layout_bbox=True,
f_draw_span_bbox=True,
if f_dump_md: f_dump_md=True,
image_dir = str(os.path.basename(local_image_dir)) f_dump_middle_json=True,
md_content_str = vlm_union_make(pdf_info, f_make_md_mode, image_dir) f_dump_model_output=True,
md_writer.write_string( f_dump_orig_pdf=True,
f"{pdf_file_name}.md", f_dump_content_list=True,
md_content_str, f_make_md_mode=MakeMode.MM_MD,
) start_page_id=0,
end_page_id=None,
if f_dump_content_list: **kwargs,
image_dir = str(os.path.basename(local_image_dir)) ):
content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir) # 预处理PDF字节数据
md_writer.write_string( pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
if f_dump_middle_json: if backend == "pipeline":
md_writer.write_string( # pipeline模式暂不支持异步,使用同步处理方式
f"{pdf_file_name}_middle.json", _process_pipeline(
json.dumps(middle_json, ensure_ascii=False, indent=4), output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
) parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
)
else:
if backend.startswith("vlm-"):
backend = backend[4:]
if f_dump_model_output: os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
model_output = ("\n" + "-" * 50 + "\n").join(infer_result) os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
md_writer.write_string(
f"{pdf_file_name}_model_output.txt",
model_output,
)
logger.info(f"local output dir is {local_md_dir}") await _async_process_vlm(
output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url, **kwargs,
)
......
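Note: a minimal calling sketch for the new `aio_do_parse` entry point defined above, based only on the signature in this diff; the input file name is a placeholder, and `asyncio` is used simply because the API is asynchronous.

import asyncio
from pathlib import Path

from mineru.cli.common import aio_do_parse, read_fn

async def run() -> None:
    pdf_path = Path("demo.pdf")  # placeholder input file
    await aio_do_parse(
        output_dir="output",
        pdf_file_names=[pdf_path.stem],
        pdf_bytes_list=[read_fn(pdf_path)],
        p_lang_list=["ch"],
        backend="vlm-transformers",  # the "vlm-" prefix is stripped internally
        formula_enable=True,         # renamed from p_formula_enable
        table_enable=True,           # renamed from p_table_enable
    )

asyncio.run(run())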
import uuid
import os
import uvicorn
import click
from pathlib import Path
from glob import glob
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from typing import List, Optional
from loguru import logger
from base64 import b64encode
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
from mineru.utils.cli_parser import arg_parse
from mineru.version import __version__
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)
def encode_image(image_path: str) -> str:
"""Encode image using base64"""
with open(image_path, "rb") as f:
return b64encode(f.read()).decode()
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
"""从结果文件中读取推理结果"""
result_file_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
if os.path.exists(result_file_path):
with open(result_file_path, "r", encoding="utf-8") as fp:
return fp.read()
return None
@app.post(path="/file_parse",)
async def parse_pdf(
files: List[UploadFile] = File(...),
output_dir: str = Form("./output"),
lang_list: List[str] = Form(["ch"]),
backend: str = Form("pipeline"),
parse_method: str = Form("auto"),
formula_enable: bool = Form(True),
table_enable: bool = Form(True),
server_url: Optional[str] = Form(None),
return_md: bool = Form(True),
return_middle_json: bool = Form(False),
return_model_output: bool = Form(False),
return_content_list: bool = Form(False),
return_images: bool = Form(False),
start_page_id: int = Form(0),
end_page_id: int = Form(99999),
):
# Get the config parameters passed on the command line
config = getattr(app.state, "config", {})
try:
# Create a unique output directory
unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
os.makedirs(unique_dir, exist_ok=True)
# Process the uploaded PDF files
pdf_file_names = []
pdf_bytes_list = []
for file in files:
content = await file.read()
file_path = Path(file.filename)
# If it is an image file or a PDF, handle it with read_fn
if file_path.suffix.lower() in pdf_suffixes + image_suffixes:
# Create a temporary file so that read_fn can be used
temp_path = Path(unique_dir) / file_path.name
with open(temp_path, "wb") as f:
f.write(content)
try:
pdf_bytes = read_fn(temp_path)
pdf_bytes_list.append(pdf_bytes)
pdf_file_names.append(file_path.stem)
os.remove(temp_path)  # remove the temporary file
except Exception as e:
return JSONResponse(
status_code=400,
content={"error": f"Failed to load file: {str(e)}"}
)
else:
return JSONResponse(
status_code=400,
content={"error": f"Unsupported file type: {file_path.suffix}"}
)
# Set up the language list and make sure it matches the number of files
actual_lang_list = lang_list
if len(actual_lang_list) != len(pdf_file_names):
# If the list length does not match, use the first language or fall back to the default "ch"
actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)
# Call the asynchronous parsing function
await aio_do_parse(
output_dir=unique_dir,
pdf_file_names=pdf_file_names,
pdf_bytes_list=pdf_bytes_list,
p_lang_list=actual_lang_list,
backend=backend,
parse_method=parse_method,
formula_enable=formula_enable,
table_enable=table_enable,
server_url=server_url,
f_draw_layout_bbox=False,
f_draw_span_bbox=False,
f_dump_md=return_md,
f_dump_middle_json=return_middle_json,
f_dump_model_output=return_model_output,
f_dump_orig_pdf=False,
f_dump_content_list=return_content_list,
start_page_id=start_page_id,
end_page_id=end_page_id,
**config
)
# Build the result paths
result_dict = {}
for pdf_name in pdf_file_names:
result_dict[pdf_name] = {}
data = result_dict[pdf_name]
if backend.startswith("pipeline"):
parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
else:
parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
if os.path.exists(parse_dir):
if return_md:
data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
if return_middle_json:
data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
if return_model_output:
if backend.startswith("pipeline"):
data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
else:
data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
if return_content_list:
data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
if return_images:
image_paths = glob(f"{parse_dir}/images/*.jpg")
data["images"] = {
os.path.basename(
image_path
): f"data:image/jpeg;base64,{encode_image(image_path)}"
for image_path in image_paths
}
return JSONResponse(
status_code=200,
content={
"backend": backend,
"version": __version__,
"results": result_dict
}
)
except Exception as e:
logger.exception(e)
return JSONResponse(
status_code=500,
content={"error": f"Failed to process file: {str(e)}"}
)
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
def main(ctx, host, port, reload, **kwargs):
kwargs.update(arg_parse(ctx))
# Store the config parameters in the application state
app.state.config = kwargs
"""启动MinerU FastAPI服务器的命令行入口"""
print(f"Start MinerU FastAPI Service: http://{host}:{port}")
print("The API documentation can be accessed at the following address:")
print(f"- Swagger UI: http://{host}:{port}/docs")
print(f"- ReDoc: http://{host}:{port}/redoc")
uvicorn.run(
"mineru.cli.fast_api:app",
host=host,
port=port,
reload=reload
)
if __name__ == "__main__":
main()
\ No newline at end of file
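Note: a minimal client sketch for the `/file_parse` endpoint defined above, assuming the third-party `requests` package; the host, port, and file name are placeholders.

import requests

with open("demo.pdf", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:8000/file_parse",
        files=[("files", ("demo.pdf", f, "application/pdf"))],
        data={
            "backend": "pipeline",
            "lang_list": ["ch"],
            "return_md": "true",
            "return_content_list": "true",
        },
    )
resp.raise_for_status()
results = resp.json()["results"]
print(results["demo"].keys())  # e.g. dict_keys(['md_content', 'content_list'])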
@@ -7,38 +7,47 @@ import time
 import zipfile
 from pathlib import Path

+import click
 import gradio as gr
 from gradio_pdf import PDF
 from loguru import logger

-from mineru.cli.common import prepare_env, do_parse, read_fn
+from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
+from mineru.utils.cli_parser import arg_parse
 from mineru.utils.hash_utils import str_sha256


-def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language):
+async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language, backend, url):
     os.makedirs(output_dir, exist_ok=True)

     try:
-        file_name = f'{str(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
+        file_name = f'{safe_stem(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
         pdf_data = read_fn(doc_path)
         if is_ocr:
             parse_method = 'ocr'
         else:
             parse_method = 'auto'
+
+        if backend.startswith("vlm"):
+            parse_method = "vlm"
+
         local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
-        do_parse(
+        await aio_do_parse(
             output_dir=output_dir,
             pdf_file_names=[file_name],
             pdf_bytes_list=[pdf_data],
             p_lang_list=[language],
             parse_method=parse_method,
             end_page_id=end_page_id,
-            p_formula_enable=formula_enable,
-            p_table_enable=table_enable,
+            formula_enable=formula_enable,
+            table_enable=table_enable,
+            backend=backend,
+            server_url=url,
         )
         return local_md_dir, file_name
     except Exception as e:
         logger.exception(e)
+        return None


 def compress_directory_to_zip(directory_path, output_zip_path):
@@ -85,16 +94,16 @@ def replace_image_with_base64(markdown_text, image_dir_path):
     return re.sub(pattern, replace, markdown_text)


-def to_markdown(file_path, end_pages, is_ocr, formula_enable, table_enable, language):
+async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True, table_enable=True, language="ch", backend="pipeline", url=None):
     file_path = to_pdf(file_path)
     # Get the recognized md file and the path of the zip archive
-    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language)
+    local_md_dir, file_name = await parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language, backend, url)
     archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
     if zip_archive_success == 0:
-        logger.info('压缩成功')
+        logger.info('Compression successful')
     else:
-        logger.error('压缩失败')
+        logger.error('Compression failed')
     md_path = os.path.join(local_md_dir, file_name + '.md')
     with open(md_path, 'r', encoding='utf-8') as f:
         txt_content = f.read()
@@ -112,9 +121,9 @@ latex_delimiters = [
     {'left': '\\[', 'right': '\\]', 'display': True},
 ]

-with open('header.html', 'r') as file:
-    header = file.read()
+header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html')
+with open(header_path, 'r') as header_file:
+    header = header_file.read()


 latin_lang = [
@@ -125,15 +134,16 @@ latin_lang = [
 ]
 arabic_lang = ['ar', 'fa', 'ug', 'ur']
 cyrillic_lang = [
-    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
+    'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
     'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
 ]
+east_slavic_lang = ["ru", "be", "uk"]
 devanagari_lang = [
     'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
     'sa', 'bgc'
 ]
 other_lang = ['ch', 'ch_lite', 'ch_server', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
-add_lang = ['latin', 'arabic', 'cyrillic', 'devanagari']
+add_lang = ['latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']

 # all_lang = ['', 'auto']
 all_lang = []
@@ -167,33 +177,125 @@ def to_pdf(file_path):
     return tmp_file_path

-if __name__ == '__main__':
+# Function that updates the interface
def update_interface(backend_choice):
if backend_choice in ["vlm-transformers", "vlm-sglang-engine"]:
return gr.update(visible=False), gr.update(visible=False)
elif backend_choice in ["vlm-sglang-client"]:
return gr.update(visible=True), gr.update(visible=False)
elif backend_choice in ["pipeline"]:
return gr.update(visible=False), gr.update(visible=True)
else:
pass
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option(
'--enable-example',
'example_enable',
type=bool,
help="Enable example files for input."
"The example files to be input need to be placed in the `example` folder within the directory where the command is currently executed.",
default=True,
)
@click.option(
'--enable-sglang-engine',
'sglang_engine_enable',
type=bool,
help="Enable SgLang engine backend for faster processing.",
default=False,
)
@click.option(
'--enable-api',
'api_enable',
type=bool,
help="Enable gradio API for serving the application.",
default=True,
)
@click.option(
'--max-convert-pages',
'max_convert_pages',
type=int,
help="Set the maximum number of pages to convert from PDF to Markdown.",
default=1000,
)
@click.option(
'--server-name',
'server_name',
type=str,
help="Set the server name for the Gradio app.",
default=None,
)
@click.option(
'--server-port',
'server_port',
type=int,
help="Set the server port for the Gradio app.",
default=None,
)
def main(ctx,
example_enable, sglang_engine_enable, api_enable, max_convert_pages,
server_name, server_port, **kwargs
):
kwargs.update(arg_parse(ctx))
if sglang_engine_enable:
try:
print("Start init SgLang engine...")
from mineru.backend.vlm.vlm_analyze import ModelSingleton
model_singleton = ModelSingleton()
predictor = model_singleton.get_model(
"sglang-engine",
None,
None,
**kwargs
)
print("SgLang engine init successfully.")
except Exception as e:
logger.exception(e)
suffixes = pdf_suffixes + image_suffixes
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.HTML(header) gr.HTML(header)
with gr.Row(): with gr.Row():
with gr.Column(variant='panel', scale=5): with gr.Column(variant='panel', scale=5):
with gr.Row(): with gr.Row():
file = gr.File(label='Please upload a PDF or image', file_types=['.pdf', '.png', '.jpeg', '.jpg']) input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
with gr.Row():
max_pages = gr.Slider(1, max_convert_pages, int(max_convert_pages/2), step=1, label='Max convert pages')
with gr.Row():
if sglang_engine_enable:
drop_list = ["pipeline", "vlm-sglang-engine"]
preferred_option = "vlm-sglang-engine"
else:
drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
preferred_option = "pipeline"
backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
with gr.Row(visible=False) as client_options:
url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
with gr.Column(scale=4): with gr.Column():
max_pages = gr.Slider(1, 20, 10, step=1, label='Max convert pages') gr.Markdown("**Recognition Options:**")
with gr.Column(scale=1): formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
table_enable = gr.Checkbox(label='Enable table recognition', value=True)
with gr.Column(visible=False) as ocr_options:
language = gr.Dropdown(all_lang, label='Language', value='ch') language = gr.Dropdown(all_lang, label='Language', value='ch')
with gr.Row(): is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
table_enable = gr.Checkbox(label='Enable table recognition(test)', value=True)
with gr.Row(): with gr.Row():
change_bu = gr.Button('Convert') change_bu = gr.Button('Convert')
clear_bu = gr.ClearButton(value='Clear') clear_bu = gr.ClearButton(value='Clear')
pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800) pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
with gr.Accordion('Examples:'): if example_enable:
example_root = os.path.join(os.path.dirname(__file__), 'examples') example_root = os.path.join(os.getcwd(), 'examples')
gr.Examples( if os.path.exists(example_root):
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if with gr.Accordion('Examples:'):
_.endswith('pdf')], gr.Examples(
inputs=file examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
) _.endswith(tuple(suffixes))],
inputs=input_file
)
with gr.Column(variant='panel', scale=5): with gr.Column(variant='panel', scale=5):
output_file = gr.File(label='convert result', interactive=False) output_file = gr.File(label='convert result', interactive=False)
...@@ -204,9 +306,38 @@ if __name__ == '__main__': ...@@ -204,9 +306,38 @@ if __name__ == '__main__':
line_breaks=True) line_breaks=True)
with gr.Tab('Markdown text'): with gr.Tab('Markdown text'):
md_text = gr.TextArea(lines=45, show_copy_button=True) md_text = gr.TextArea(lines=45, show_copy_button=True)
-        file.change(fn=to_pdf, inputs=file, outputs=pdf_show)
-        change_bu.click(fn=to_markdown, inputs=[file, max_pages, is_ocr, formula_enable, table_enable, language],
-                        outputs=[md, md_text, output_file, pdf_show])
-        clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr])
-
-    demo.launch(server_name='0.0.0.0')
+        # Add event handling
+        backend.change(
+            fn=update_interface,
+            inputs=[backend],
+            outputs=[client_options, ocr_options],
+            api_name=False
+        )
+        # Add a demo.load event to trigger one interface update when the page loads
+        demo.load(
+            fn=update_interface,
+            inputs=[backend],
+            outputs=[client_options, ocr_options],
+            api_name=False
+        )
+        clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])
+
+        if api_enable:
+            api_name = None
+        else:
+            api_name = False
+
+        input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show, api_name=api_name)
+        change_bu.click(
+            fn=to_markdown,
+            inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
+            outputs=[md, md_text, output_file, pdf_show],
+            api_name=api_name
+        )
+
+    demo.launch(server_name=server_name, server_port=server_port, show_api=api_enable)


 if __name__ == '__main__':
+    main()
\ No newline at end of file
@@ -3,6 +3,7 @@ import os
 import sys
 import click
 import requests
+from loguru import logger

 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -54,7 +55,32 @@ def configure_model(model_dir, model_type):
     }
     download_and_modify_json(json_url, config_file, json_mods)
-    print(f'The configuration file has been successfully configured, the path is: {config_file}')
+    logger.info(f'The configuration file has been successfully configured, the path is: {config_file}')
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus
]
download_finish_path = ""
for model_path in model_paths:
logger.info(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
logger.info(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "pipeline")
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
logger.info(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "vlm")
@click.command() @click.command()
...@@ -102,30 +128,7 @@ def download_models(model_source, model_type): ...@@ -102,30 +128,7 @@ def download_models(model_source, model_type):
default='all' default='all'
) )
click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...") logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus
]
download_finish_path = ""
for model_path in model_paths:
click.echo(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "pipeline")
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "vlm")
try: try:
if model_type == 'pipeline': if model_type == 'pipeline':
...@@ -140,7 +143,7 @@ def download_models(model_source, model_type): ...@@ -140,7 +143,7 @@ def download_models(model_source, model_type):
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
click.echo(f"Download failed: {str(e)}", err=True) logger.exception(f"An error occurred while downloading models: {str(e)}")
sys.exit(1) sys.exit(1)
if __name__ == '__main__': if __name__ == '__main__':
......
@@ -26,9 +26,10 @@ latin_lang = [
 ]
 arabic_lang = ['ar', 'fa', 'ug', 'ur']
 cyrillic_lang = [
-    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
+    'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
     'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
 ]
+east_slavic_lang = ["ru", "be", "uk"]
 devanagari_lang = [
     'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
     'sa', 'bgc'
@@ -69,6 +70,8 @@ class PytorchPaddleOCR(TextSystem):
             self.lang = 'cyrillic'
         elif self.lang in devanagari_lang:
             self.lang = 'devanagari'
+        elif self.lang in east_slavic_lang:
+            self.lang = 'east_slavic'
         else:
             pass
......
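Note: the net effect of the language-group change above is that 'ru', 'be', and 'uk' now resolve to the dedicated east_slavic recognizer instead of the generic cyrillic one. An illustrative, standalone sketch of that mapping (not MinerU API):

east_slavic_lang = ["ru", "be", "uk"]
cyrillic_lang = [
    'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava',
    'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]

def resolve_lang_group(lang: str) -> str:
    # Simplified mirror of the elif chain in PytorchPaddleOCR
    if lang in cyrillic_lang:
        return 'cyrillic'
    if lang in east_slavic_lang:
        return 'east_slavic'
    return lang

assert resolve_lang_group('ru') == 'east_slavic'
assert resolve_lang_group('bg') == 'cyrillic'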
...@@ -490,3 +490,82 @@ devanagari_PP-OCRv3_rec_infer: ...@@ -490,3 +490,82 @@ devanagari_PP-OCRv3_rec_infer:
# out_channels: 169 # out_channels: 169
fc_decay: 0.00001 fc_decay: 0.00001
korean_PP-OCRv5_rec_infer:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
out_channels_list:
CTCLabelDecode: 11947
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [ 1, 3 ]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: 25
latin_PP-OCRv5_rec_infer:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
out_channels_list:
CTCLabelDecode: 504
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [ 1, 3 ]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: 25
eslav_PP-OCRv5_rec_infer:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
out_channels_list:
CTCLabelDecode: 519
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [ 1, 3 ]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: 25
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
©
{
}
\
|
@
^
~
÷
·
±
®
Ω
¢
£
¥
𝑢
𝜓
ƒ
À
Á
Â
Ã
Ä
Å
Æ
Ç
È
É
Ê
Ë
Ì
Í
Î
Ï
Ð
Ñ
Ò
Ó
Ô
Õ
Ö
Ø
Ù
Ú
Û
Ü
Ý
Þ
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
ø
ù
ú
û
ü
ý
þ
ÿ
¡
¤
¦
§
¨
ª
«
¬
¯
°
²
³
´
µ
¸
¹
º
»
¼
½
¾
¿
×
Α
α
Β
β
Γ
γ
Δ
δ
Ε
ε
Ζ
ζ
Η
η
Θ
θ
Ι
ι
Κ
κ
Λ
λ
Μ
μ
Ν
ν
Ξ
ξ
Ο
ο
Π
π
Ρ
ρ
Σ
σ
ς
Τ
τ
Υ
υ
Φ
φ
Χ
χ
Ψ
ψ
ω
А
Б
В
Г
Ґ
Д
Е
Ё
Є
Ж
З
И
І
Ї
Й
К
Л
М
Н
О
П
Р
С
Т
У
Ў
Ф
Х
Ц
Ч
Ш
Щ
Ъ
Ы
Ь
Э
Ю
Я
а
б
в
г
ґ
д
е
ё
є
ж
з
и
і
ї
й
к
л
м
н
о
п
р
с
т
у
ў
ф
х
ц
ч
ш
щ
ъ
ы
ь
э
ю
я
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
¡
¢
£
¤
¥
¦
§
¨
©
ª
«
¬
­
®
¯
°
±
²
³
´
µ
·
¸
¹
º
»
¼
½
¾
¿
À
Á
Â
Ã
Ä
Å
Æ
Ç
È
É
Ê
Ë
Ì
Í
Î
Ï
Ð
Ñ
Ò
Ó
Ô
Õ
Ö
×
Ø
Ù
Ú
Û
Ü
Ý
Þ
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
÷
ø
ù
ú
û
ü
ý
þ
ÿ
Ą
ą
Ć
ć
Č
č
Ď
ď
Đ
đ
Ė
ė
Ę
ę
Ě
ě
Ğ
ğ
Į
į
İ
ı
Ĺ
ĺ
Ľ
ľ
Ł
ł
Ń
ń
Ň
ň
ō
Ő
ő
Œ
œ
Ŕ
ŕ
Ř
ř
Ś
ś
Ş
ş
Š
š
Ť
ť
Ū
ū
Ů
ů
Ű
ű
Ų
ų
Ÿ
Ź
ź
Ż
ż
Ž
ž
ƒ
ʒ
Ω
α
β
γ
δ
ε
ζ
η
θ
ι
κ
λ
μ
ν
ξ
ο
π
ρ
ς
σ
τ
υ
φ
χ
ψ
ω
з
𝑢
𝜓
@@ -24,17 +24,17 @@ lang:
     rec: en_PP-OCRv4_rec_infer.pth
     dict: en_dict.txt
   korean:
-    det: Multilingual_PP-OCRv3_det_infer.pth
-    rec: korean_PP-OCRv3_rec_infer.pth
-    dict: korean_dict.txt
+    det: ch_PP-OCRv5_det_infer.pth
+    rec: korean_PP-OCRv5_rec_infer.pth
+    dict: ppocrv5_korean_dict.txt
   japan:
     det: ch_PP-OCRv5_det_infer.pth
     rec: ch_PP-OCRv5_rec_server_infer.pth
-    dict: japan_dict.txt
+    dict: ppocrv5_dict.txt
   chinese_cht:
     det: ch_PP-OCRv5_det_infer.pth
     rec: ch_PP-OCRv5_rec_server_infer.pth
-    dict: chinese_cht_dict.txt
+    dict: ppocrv5_dict.txt
   ta:
     det: Multilingual_PP-OCRv3_det_infer.pth
     rec: ta_PP-OCRv3_rec_infer.pth
@@ -48,9 +48,9 @@ lang:
     rec: ka_PP-OCRv3_rec_infer.pth
     dict: ka_dict.txt
   latin:
-    det: en_PP-OCRv3_det_infer.pth
-    rec: latin_PP-OCRv3_rec_infer.pth
-    dict: latin_dict.txt
+    det: ch_PP-OCRv5_det_infer.pth
+    rec: latin_PP-OCRv5_rec_infer.pth
+    dict: ppocrv5_latin_dict.txt
   arabic:
     det: Multilingual_PP-OCRv3_det_infer.pth
     rec: arabic_PP-OCRv3_rec_infer.pth
@@ -62,4 +62,8 @@ lang:
   devanagari:
     det: Multilingual_PP-OCRv3_det_infer.pth
     rec: devanagari_PP-OCRv3_rec_infer.pth
     dict: devanagari_dict.txt
+  east_slavic:
+    det: ch_PP-OCRv5_det_infer.pth
+    rec: eslav_PP-OCRv5_rec_infer.pth
+    dict: ppocrv5_eslav_dict.txt
\ No newline at end of file
@@ -54,7 +54,7 @@
           font-family: 'Trebuchet MS', 'Lucida Sans Unicode',
             'Lucida Grande', 'Lucida Sans', Arial, sans-serif;
         ">
-        MinerU: PDF Extraction Demo
+        MinerU 2: PDF Extraction Demo
       </h1>
     </div>
   </div>
@@ -66,8 +66,7 @@
       color: #fafafa;
       opacity: 0.8;
     ">
-    A one-stop, open-source, high-quality data extraction tool, supports
-    PDF/webpage/e-book extraction.<br>
+    A one-stop, open-source, high-quality data extraction tool that supports converting PDF to Markdown and JSON.<br>
   </p>
   <style>
     .link-block {
......
@@ -90,8 +90,8 @@ def prepare_block_bboxes(
     """After the steps above, nested boxes may remain (a large box containing a small one); drop the small ones."""
     all_bboxes = remove_overlaps_min_blocks(all_bboxes)
     all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
-    """Separate the remaining bboxes to avoid errors when splitting the layout later."""
-    # all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
+    """Return after a coarse sort."""
     all_bboxes.sort(key=lambda x: x[0]+x[1])
     return all_bboxes, all_discarded_blocks, footnote_blocks
@@ -213,35 +213,39 @@ def remove_overlaps_min_blocks(all_bboxes):
     # For overlapping blocks, the smaller one cannot simply be dropped; it has to be merged into the larger one.
     # Remove the smaller of the overlapping blocks
     need_remove = []
-    for block1 in all_bboxes:
-        for block2 in all_bboxes:
-            if block1 != block2:
-                block1_bbox = block1[:4]
-                block2_bbox = block2[:4]
-                overlap_box = get_minbox_if_overlap_by_ratio(
-                    block1_bbox, block2_bbox, 0.8
-                )
-                if overlap_box is not None:
-                    block_to_remove = next(
-                        (block for block in all_bboxes if block[:4] == overlap_box),
-                        None,
-                    )
-                    if (
-                        block_to_remove is not None
-                        and block_to_remove not in need_remove
-                    ):
-                        large_block = block1 if block1 != block_to_remove else block2
-                        x1, y1, x2, y2 = large_block[:4]
-                        sx1, sy1, sx2, sy2 = block_to_remove[:4]
-                        x1 = min(x1, sx1)
-                        y1 = min(y1, sy1)
-                        x2 = max(x2, sx2)
-                        y2 = max(y2, sy2)
-                        large_block[:4] = [x1, y1, x2, y2]
-                        need_remove.append(block_to_remove)
-
-    if len(need_remove) > 0:
-        for block in need_remove:
+    for i in range(len(all_bboxes)):
+        for j in range(i + 1, len(all_bboxes)):
+            block1 = all_bboxes[i]
+            block2 = all_bboxes[j]
+            block1_bbox = block1[:4]
+            block2_bbox = block2[:4]
+            overlap_box = get_minbox_if_overlap_by_ratio(
+                block1_bbox, block2_bbox, 0.8
+            )
+            if overlap_box is not None:
+                # Decide which block has the smaller area and remove that one
+                area1 = (block1[2] - block1[0]) * (block1[3] - block1[1])
+                area2 = (block2[2] - block2[0]) * (block2[3] - block2[1])
+
+                if area1 <= area2:
+                    block_to_remove = block1
+                    large_block = block2
+                else:
+                    block_to_remove = block2
+                    large_block = block1
+
+                if block_to_remove not in need_remove:
+                    x1, y1, x2, y2 = large_block[:4]
+                    sx1, sy1, sx2, sy2 = block_to_remove[:4]
+                    x1 = min(x1, sx1)
+                    y1 = min(y1, sy1)
+                    x2 = max(x2, sx2)
+                    y2 = max(y2, sy2)
+                    large_block[:4] = [x1, y1, x2, y2]
+                    need_remove.append(block_to_remove)
+
+    for block in need_remove:
+        if block in all_bboxes:
             all_bboxes.remove(block)

     return all_bboxes
\ No newline at end of file
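Note: the rewritten loop above decides which of two heavily overlapping blocks to keep by comparing areas, then grows the kept block to cover the removed one. A self-contained illustration of that area rule (not MinerU code, values made up):

def merge_keep_larger(b1, b2):
    # Keep the larger block and expand it to the union of both bboxes
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
    small, large = (b1, b2) if area1 <= area2 else (b2, b1)
    return [min(large[0], small[0]), min(large[1], small[1]),
            max(large[2], small[2]), max(large[3], small[3])]

print(merge_keep_larger([10, 10, 20, 20], [0, 0, 100, 100]))  # [0, 0, 100, 100]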
import click
def arg_parse(ctx: 'click.Context') -> dict:
# 解析额外参数
extra_kwargs = {}
i = 0
while i < len(ctx.args):
arg = ctx.args[i]
if arg.startswith('--'):
param_name = arg[2:].replace('-', '_') # 转换参数名格式
i += 1
if i < len(ctx.args) and not ctx.args[i].startswith('--'):
# The parameter has a value
try:
# Try to convert it to an appropriate type
if ctx.args[i].lower() == 'true':
extra_kwargs[param_name] = True
elif ctx.args[i].lower() == 'false':
extra_kwargs[param_name] = False
elif '.' in ctx.args[i]:
try:
extra_kwargs[param_name] = float(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
else:
try:
extra_kwargs[param_name] = int(ctx.args[i])
except ValueError:
extra_kwargs[param_name] = ctx.args[i]
except:
extra_kwargs[param_name] = ctx.args[i]
else:
# Boolean flag argument
extra_kwargs[param_name] = True
i -= 1
i += 1
return extra_kwargs
\ No newline at end of file
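arg_parse only sees the leftover arguments that click collects into ctx.args, so the command has to opt in to pass-through. A hedged sketch of that wiring (the command and option names are illustrative, not MinerU's CLI):

import click

@click.command(context_settings={'ignore_unknown_options': True, 'allow_extra_args': True})
@click.option('-p', '--path', required=True, help='Input PDF path')
@click.pass_context
def demo(ctx, path):
    # Unrecognized "--key value" pairs and bare "--flag" switches land in ctx.args
    # and are turned into a kwargs dict by arg_parse().
    extra_kwargs = arg_parse(ctx)
    click.echo(f'path={path} extra={extra_kwargs}')

# e.g. `demo -p a.pdf --start-page-id 2 --formula-enable false --debug`
# yields extra == {'start_page_id': 2, 'formula_enable': False, 'debug': True}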
...@@ -21,6 +21,7 @@ class ContentType: ...@@ -21,6 +21,7 @@ class ContentType:
TEXT = 'text' TEXT = 'text'
INTERLINE_EQUATION = 'interline_equation' INTERLINE_EQUATION = 'interline_equation'
INLINE_EQUATION = 'inline_equation' INLINE_EQUATION = 'inline_equation'
EQUATION = 'equation'
class CategoryId: class CategoryId:
......
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
from loguru import logger from loguru import logger
from openai import OpenAI from openai import OpenAI
import ast import json_repair
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text
...@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config): ...@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config):
if block["type"] == "title": if block["type"] == "title":
origin_title_list.append(block) origin_title_list.append(block)
title_text = merge_para_with_text(block) title_text = merge_para_with_text(block)
page_line_height_list = []
for line in block['lines']: if 'line_avg_height' in block:
bbox = line['bbox'] line_avg_height = block['line_avg_height']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
else: else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1]) title_block_line_height_list = []
for line in block['lines']:
bbox = line['bbox']
title_block_line_height_list.append(int(bbox[3] - bbox[1]))
if len(title_block_line_height_list) > 0:
line_avg_height = sum(title_block_line_height_list) / len(title_block_line_height_list)
else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1] title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
i += 1 i += 1
# logger.info(f"Title list: {title_dict}") # logger.info(f"Title list: {title_dict}")
...@@ -85,13 +90,17 @@ Corrected title list: ...@@ -85,13 +90,17 @@ Corrected title list:
messages=[ messages=[
{'role': 'user', 'content': title_optimize_prompt}], {'role': 'user', 'content': title_optimize_prompt}],
temperature=0.7, temperature=0.7,
stream=True,
) )
content = completion.choices[0].message.content.strip() content_pieces = []
for chunk in completion:
if chunk.choices and chunk.choices[0].delta.content is not None:
content_pieces.append(chunk.choices[0].delta.content)
content = "".join(content_pieces).strip()
# logger.info(f"Title completion: {content}") # logger.info(f"Title completion: {content}")
if "</think>" in content: if "</think>" in content:
idx = content.index("</think>") + len("</think>") idx = content.index("</think>") + len("</think>")
content = content[idx:].strip() content = content[idx:].strip()
import json_repair
dict_completion = json_repair.loads(content) dict_completion = json_repair.loads(content)
dict_completion = {int(k): int(v) for k, v in dict_completion.items()} dict_completion = {int(k): int(v) for k, v in dict_completion.items()}
......
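With stream=True the response arrives as chunks whose delta payloads must be concatenated before the JSON repair step. A minimal sketch of that pattern in isolation (model name, endpoint, and prompt are placeholders):

from openai import OpenAI
import json_repair

client = OpenAI(api_key='sk-placeholder', base_url='https://example.com/v1')  # illustrative
stream = client.chat.completions.create(
    model='placeholder-model',
    messages=[{'role': 'user', 'content': 'Return {"1": 2} as JSON.'}],
    temperature=0.7,
    stream=True,
)
pieces = []
for chunk in stream:
    # Some chunks (e.g. the terminating one) carry no delta content.
    if chunk.choices and chunk.choices[0].delta.content is not None:
        pieces.append(chunk.choices[0].delta.content)
content = ''.join(pieces).strip()
# Drop a reasoning prefix if the model wraps it in <think>...</think>.
if '</think>' in content:
    content = content[content.index('</think>') + len('</think>'):].strip()
dict_completion = json_repair.loads(content)  # tolerant of minor JSON defects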
...@@ -5,9 +5,11 @@ import numpy as np ...@@ -5,9 +5,11 @@ import numpy as np
class OcrConfidence: class OcrConfidence:
min_confidence = 0.68 min_confidence = 0.5
min_width = 3 min_width = 3
LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD = 4 # In general, a line is a normal horizontal text block only when its width exceeds 4x its height
def merge_spans_to_line(spans, threshold=0.6): def merge_spans_to_line(spans, threshold=0.6):
if len(spans) == 0: if len(spans) == 0:
...@@ -20,7 +22,7 @@ def merge_spans_to_line(spans, threshold=0.6): ...@@ -20,7 +22,7 @@ def merge_spans_to_line(spans, threshold=0.6):
current_line = [spans[0]] current_line = [spans[0]]
for span in spans[1:]: for span in spans[1:]:
# If the current span overlaps the last span of the current line on the y-axis, add it to the current line # If the current span overlaps the last span of the current line on the y-axis, add it to the current line
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold): if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
current_line.append(span) current_line.append(span)
else: else:
# Otherwise, start a new line # Otherwise, start a new line
...@@ -33,9 +35,9 @@ def merge_spans_to_line(spans, threshold=0.6): ...@@ -33,9 +35,9 @@ def merge_spans_to_line(spans, threshold=0.6):
return lines return lines
def __is_overlaps_y_exceeds_threshold(bbox1, def _is_overlaps_y_exceeds_threshold(bbox1,
bbox2, bbox2,
overlap_ratio_threshold=0.8): overlap_ratio_threshold=0.8):
"""检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%""" """检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
_, y0_1, _, y1_1 = bbox1 _, y0_1, _, y1_1 = bbox1
_, y0_2, _, y1_2 = bbox2 _, y0_2, _, y1_2 = bbox2
...@@ -45,7 +47,21 @@ def __is_overlaps_y_exceeds_threshold(bbox1, ...@@ -45,7 +47,21 @@ def __is_overlaps_y_exceeds_threshold(bbox1,
# max_height = max(height1, height2) # max_height = max(height1, height2)
min_height = min(height1, height2) min_height = min(height1, height2)
return (overlap / min_height) > overlap_ratio_threshold return (overlap / min_height) > overlap_ratio_threshold if min_height > 0 else False
def _is_overlaps_x_exceeds_threshold(bbox1,
bbox2,
overlap_ratio_threshold=0.8):
"""检查两个bbox在x轴上是否有重叠,并且该重叠区域的宽度占两个bbox宽度更低的那个超过指定阈值"""
x0_1, _, x1_1, _ = bbox1
x0_2, _, x1_2, _ = bbox2
overlap = max(0, min(x1_1, x1_2) - max(x0_1, x0_2))
width1, width2 = x1_1 - x0_1, x1_2 - x0_2
min_width = min(width1, width2)
return (overlap / min_width) > overlap_ratio_threshold if min_width > 0 else False
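A quick sanity check of the new x-axis helper on made-up box pairs (the y-axis variant behaves the same way along the other axis); bbox layout is [x0, y0, x1, y1]:

assert _is_overlaps_x_exceeds_threshold([0, 0, 10, 5], [1, 0, 5, 5]) is True    # overlap 4 / min width 4 = 1.0
assert _is_overlaps_x_exceeds_threshold([0, 0, 10, 5], [8, 0, 12, 5]) is False  # overlap 2 / min width 4 = 0.5
assert _is_overlaps_x_exceeds_threshold([0, 0, 10, 5], [3, 0, 3, 5]) is False   # zero-width box is rejected by the guard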
def img_decode(content: bytes): def img_decode(content: bytes):
...@@ -178,7 +194,7 @@ def update_det_boxes(dt_boxes, mfd_res): ...@@ -178,7 +194,7 @@ def update_det_boxes(dt_boxes, mfd_res):
masks_list = [] masks_list = []
for mf_box in mfd_res: for mf_box in mfd_res:
mf_bbox = mf_box['bbox'] mf_bbox = mf_box['bbox']
if __is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox): if _is_overlaps_y_exceeds_threshold(text_bbox, mf_bbox):
masks_list.append([mf_bbox[0], mf_bbox[2]]) masks_list.append([mf_bbox[0], mf_bbox[2]])
text_x_range = [text_bbox[0], text_bbox[2]] text_x_range = [text_bbox[0], text_bbox[2]]
text_remove_mask_range = remove_intervals(text_x_range, masks_list) text_remove_mask_range = remove_intervals(text_x_range, masks_list)
...@@ -266,12 +282,27 @@ def merge_det_boxes(dt_boxes): ...@@ -266,12 +282,27 @@ def merge_det_boxes(dt_boxes):
for span in line: for span in line:
line_bbox_list.append(span['bbox']) line_bbox_list.append(span['bbox'])
# Merge overlapping text regions within the same line # Compute the width and height of the whole line
merged_spans = merge_overlapping_spans(line_bbox_list) min_x = min(bbox[0] for bbox in line_bbox_list)
max_x = max(bbox[2] for bbox in line_bbox_list)
min_y = min(bbox[1] for bbox in line_bbox_list)
max_y = max(bbox[3] for bbox in line_bbox_list)
line_width = max_x - min_x
line_height = max_y - min_y
# Merge only when the line width exceeds 4x the line height
if line_width > line_height * LINE_WIDTH_TO_HEIGHT_RATIO_THRESHOLD:
# Convert the merged text regions back to point format and add them to the new detection box list # Merge overlapping text regions within the same line
for span in merged_spans: merged_spans = merge_overlapping_spans(line_bbox_list)
new_dt_boxes.append(bbox_to_points(span))
# Convert the merged text regions back to point format and add them to the new detection box list
for span in merged_spans:
new_dt_boxes.append(bbox_to_points(span))
else:
# Do not merge; add the original regions directly
for bbox in line_bbox_list:
new_dt_boxes.append(bbox_to_points(bbox))
new_dt_boxes.extend(angle_boxes_list) new_dt_boxes.extend(angle_boxes_list)
......
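The new gate only merges detection spans when the assembled line is clearly horizontal. A toy illustration with made-up coordinates:

# Two spans sitting on one text line: the overall bbox is 120 wide by 13 tall.
line_bbox_list = [[10, 100, 60, 112], [62, 101, 130, 113]]
min_x = min(b[0] for b in line_bbox_list); max_x = max(b[2] for b in line_bbox_list)
min_y = min(b[1] for b in line_bbox_list); max_y = max(b[3] for b in line_bbox_list)
# width/height ratio ~9.2 > 4, so the spans may be merged; a tall, narrow "line"
# (e.g. vertical CJK text) fails the gate and its boxes are kept separate.
should_merge = (max_x - min_x) > (max_y - min_y) * 4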
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.enum_class import BlockType, ContentType from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.ocr_utils import __is_overlaps_y_exceeds_threshold from mineru.utils.ocr_utils import _is_overlaps_y_exceeds_threshold, _is_overlaps_x_exceeds_threshold
VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD = 2
VERTICAL_SPAN_IN_BLOCK_THRESHOLD = 0.8
def fill_spans_in_blocks(blocks, spans, radio): def fill_spans_in_blocks(blocks, spans, radio):
"""将allspans中的span按位置关系,放入blocks中.""" """将allspans中的span按位置关系,放入blocks中."""
...@@ -71,8 +73,26 @@ def fix_text_block(block): ...@@ -71,8 +73,26 @@ def fix_text_block(block):
for span in block['spans']: for span in block['spans']:
if span['type'] == ContentType.INTERLINE_EQUATION: if span['type'] == ContentType.INTERLINE_EQUATION:
span['type'] = ContentType.INLINE_EQUATION span['type'] = ContentType.INLINE_EQUATION
block_lines = merge_spans_to_line(block['spans'])
sort_block_lines = line_sort_spans_by_left_to_right(block_lines) # If more than 80% of the spans in a block are more than twice as tall as they are wide, treat the block as vertical text
vertical_span_count = sum(
1 for span in block['spans']
if (span['bbox'][3] - span['bbox'][1]) / (span['bbox'][2] - span['bbox'][0]) > VERTICAL_SPAN_HEIGHT_TO_WIDTH_RATIO_THRESHOLD
)
total_span_count = len(block['spans'])
if total_span_count == 0:
vertical_ratio = 0
else:
vertical_ratio = vertical_span_count / total_span_count
if vertical_ratio > VERTICAL_SPAN_IN_BLOCK_THRESHOLD:
# For a vertical text block, process it as vertical lines
block_lines = merge_spans_to_vertical_line(block['spans'])
sort_block_lines = vertical_line_sort_spans_from_top_to_bottom(block_lines)
else:
block_lines = merge_spans_to_line(block['spans'])
sort_block_lines = line_sort_spans_by_left_to_right(block_lines)
block['lines'] = sort_block_lines block['lines'] = sort_block_lines
del block['spans'] del block['spans']
return block return block
...@@ -103,7 +123,7 @@ def merge_spans_to_line(spans, threshold=0.6): ...@@ -103,7 +123,7 @@ def merge_spans_to_line(spans, threshold=0.6):
continue continue
# If the current span overlaps the last span of the current line on the y-axis, add it to the current line # If the current span overlaps the last span of the current line on the y-axis, add it to the current line
if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold): if _is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
current_line.append(span) current_line.append(span)
else: else:
# Otherwise, start a new line # Otherwise, start a new line
...@@ -117,6 +137,44 @@ def merge_spans_to_line(spans, threshold=0.6): ...@@ -117,6 +137,44 @@ def merge_spans_to_line(spans, threshold=0.6):
return lines return lines
def merge_spans_to_vertical_line(spans, threshold=0.6):
"""将纵向文本的spans合并成纵向lines(从右向左阅读)"""
if len(spans) == 0:
return []
else:
# Sort by the x2 coordinate in descending order (right to left)
spans.sort(key=lambda span: span['bbox'][2], reverse=True)
vertical_lines = []
current_line = [spans[0]]
for span in spans[1:]:
# Spans of special types get their own column
if span['type'] in [
ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
ContentType.TABLE
] or any(s['type'] in [
ContentType.INTERLINE_EQUATION, ContentType.IMAGE,
ContentType.TABLE
] for s in current_line):
vertical_lines.append(current_line)
current_line = [span]
continue
# If the current span overlaps the last span of the current column on the x-axis, add it to the current column
if _is_overlaps_x_exceeds_threshold(span['bbox'], current_line[-1]['bbox'], threshold):
current_line.append(span)
else:
vertical_lines.append(current_line)
current_line = [span]
# Append the last column
if current_line:
vertical_lines.append(current_line)
return vertical_lines
# Sort the spans in each line from left to right # Sort the spans in each line from left to right
def line_sort_spans_by_left_to_right(lines): def line_sort_spans_by_left_to_right(lines):
line_objects = [] line_objects = []
...@@ -136,6 +194,28 @@ def line_sort_spans_by_left_to_right(lines): ...@@ -136,6 +194,28 @@ def line_sort_spans_by_left_to_right(lines):
return line_objects return line_objects
def vertical_line_sort_spans_from_top_to_bottom(vertical_lines):
line_objects = []
for line in vertical_lines:
# Sort by the y0 coordinate (top to bottom)
line.sort(key=lambda span: span['bbox'][1])
# Compute the bounding box of the whole column
line_bbox = [
min(span['bbox'][0] for span in line), # x0
min(span['bbox'][1] for span in line), # y0
max(span['bbox'][2] for span in line), # x1
max(span['bbox'][3] for span in line), # y1
]
# Assemble the result
line_objects.append({
'bbox': line_bbox,
'spans': line,
})
return line_objects
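Taken together, the two new helpers turn right-to-left columns into top-to-bottom ordered lines. A toy end-to-end run (span dicts are trimmed to the fields these helpers actually read, and the coordinates are made up):

spans = [
    {'bbox': [80, 10, 92, 60], 'type': 'text'},    # right-most column, read first
    {'bbox': [60, 12, 72, 58], 'type': 'text'},
    {'bbox': [60, 62, 72, 110], 'type': 'text'},   # same column as above, lower on the page
]
columns = merge_spans_to_vertical_line(spans, threshold=0.6)
lines = vertical_line_sort_spans_from_top_to_bottom(columns)
# -> two line objects: first the x~80 column, then the x~60 column with its
#    two spans ordered by y0; each carries the union bbox of its spans.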
def fix_block_spans(block_with_spans): def fix_block_spans(block_with_spans):
fix_blocks = [] fix_blocks = []
for block in block_with_spans: for block in block_with_spans:
......
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
import collections
import re import re
import statistics import statistics
...@@ -187,7 +188,7 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded ...@@ -187,7 +188,7 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded
span['chars'] = [] span['chars'] = []
new_spans.append(span) new_spans.append(span)
need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars) need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars, median_span_height)
"""对未填充的span进行ocr""" """对未填充的span进行ocr"""
if len(need_ocr_spans) > 0: if len(need_ocr_spans) > 0:
...@@ -208,14 +209,26 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded ...@@ -208,14 +209,26 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded
return spans return spans
def fill_char_in_spans(spans, all_chars): def fill_char_in_spans(spans, all_chars, median_span_height):
# Simply sort from top to bottom # Simply sort from top to bottom
spans = sorted(spans, key=lambda x: x['bbox'][1]) spans = sorted(spans, key=lambda x: x['bbox'][1])
grid_size = median_span_height
grid = collections.defaultdict(list)
for i, span in enumerate(spans):
start_cell = int(span['bbox'][1] / grid_size)
end_cell = int(span['bbox'][3] / grid_size)
for cell_idx in range(start_cell, end_cell + 1):
grid[cell_idx].append(i)
for char in all_chars: for char in all_chars:
char_center_y = (char['bbox'][1] + char['bbox'][3]) / 2
cell_idx = int(char_center_y / grid_size)
candidate_span_indices = grid.get(cell_idx, [])
for span in spans: for span_idx in candidate_span_indices:
span = spans[span_idx]
if calculate_char_in_span(char['bbox'], span['bbox'], char['char']): if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
span['chars'].append(char) span['chars'].append(char)
break break
......
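The defaultdict grid is a one-dimensional spatial index: each span registers every vertical cell it touches, and each char only tests the spans bucketed in the cell holding its own center. A small standalone sketch of the idea with invented numbers:

import collections

grid_size = 12          # stands in for median_span_height
spans = [{'bbox': [0, 0, 100, 12], 'chars': []},
         {'bbox': [0, 40, 100, 52], 'chars': []}]

grid = collections.defaultdict(list)
for i, span in enumerate(spans):
    # A span is registered in every cell its vertical extent touches.
    for cell_idx in range(int(span['bbox'][1] / grid_size), int(span['bbox'][3] / grid_size) + 1):
        grid[cell_idx].append(i)

char = {'bbox': [5, 42, 10, 50], 'char': 'a'}
char_center_y = (char['bbox'][1] + char['bbox'][3]) / 2   # 46.0 -> cell 3
candidates = grid.get(int(char_center_y / grid_size), []) # only the second span
# Per-char work drops from "check every span" to "check spans sharing a row bucket".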