Unverified Commit 919280aa authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge branch 'dev' into multi_gpu_v2

parents ea9336c0 c6881d83
import re
import time
import cv2
import numpy as np
from loguru import logger
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.utils.config_reader import get_llm_aided_config
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType
from mineru.utils.enum_class import ContentType
from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.pdf_image_tools import get_crop_img
from mineru.version import __version__
......@@ -23,6 +30,34 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
image_blocks = magic_model.get_image_blocks()
table_blocks = magic_model.get_table_blocks()
title_blocks = magic_model.get_title_blocks()
# 如果有标题优化需求,则对title_blocks截图det
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang='ch_lite'
)
for title_block in title_blocks:
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
title_np_img = np.array(title_pil_img)
# 给title_pil_img添加上下左右各50像素白边padding
title_np_img = cv2.copyMakeBorder(
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# 计算所有res的平均高度
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks()
......@@ -48,6 +83,19 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
image_dict = images_list[index]
page_info = token_to_page_info(token, image_dict, page, image_writer, index)
middle_json["pdf_info"].append(page_info)
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
# 关闭pdf文档
pdf_doc.close()
return middle_json
......
......@@ -25,6 +25,7 @@ class ModelSingleton:
backend: str,
model_path: str | None,
server_url: str | None,
**kwargs,
) -> BasePredictor:
key = (backend, model_path, server_url)
if key not in self._models:
......@@ -34,6 +35,7 @@ class ModelSingleton:
backend=backend,
model_path=model_path,
server_url=server_url,
**kwargs,
)
return self._models[key]
......@@ -45,9 +47,10 @@ def doc_analyze(
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
# load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
......@@ -71,19 +74,20 @@ async def aio_doc_analyze(
backend="transformers",
model_path: str | None = None,
server_url: str | None = None,
**kwargs,
):
if predictor is None:
predictor = ModelSingleton().get_model(backend, model_path, server_url)
predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
load_images_start = time.time()
# load_images_start = time.time()
images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
load_images_time = round(time.time() - load_images_start, 2)
logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
# load_images_time = round(time.time() - load_images_start, 2)
# logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
infer_start = time.time()
# infer_start = time.time()
results = await predictor.aio_batch_predict(images=images_base64_list)
infer_time = round(time.time() - infer_start, 2)
logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
# infer_time = round(time.time() - infer_start, 2)
# logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
return middle_json
return middle_json, results
import re
from typing import Literal
from loguru import logger
from mineru.utils.boxbase import bbox_distance, is_in
from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
......@@ -22,25 +24,30 @@ class MagicModel:
# 解析每个块
for index, block_info in enumerate(block_infos):
block_bbox = block_info[0].strip()
x1, y1, x2, y2 = map(int, block_bbox.split())
x_1, y_1, x_2, y_2 = (
int(x1 * width / 1000),
int(y1 * height / 1000),
int(x2 * width / 1000),
int(y2 * height / 1000),
)
if x_2 < x_1:
x_1, x_2 = x_2, x_1
if y_2 < y_1:
y_1, y_2 = y_2, y_1
block_bbox = (x_1, y_1, x_2, y_2)
block_type = block_info[1].strip()
block_content = block_info[2].strip()
# print(f"坐标: {block_bbox}")
# print(f"类型: {block_type}")
# print(f"内容: {block_content}")
# print("-" * 50)
try:
x1, y1, x2, y2 = map(int, block_bbox.split())
x_1, y_1, x_2, y_2 = (
int(x1 * width / 1000),
int(y1 * height / 1000),
int(x2 * width / 1000),
int(y2 * height / 1000),
)
if x_2 < x_1:
x_1, x_2 = x_2, x_1
if y_2 < y_1:
y_1, y_2 = y_2, y_1
block_bbox = (x_1, y_1, x_2, y_2)
block_type = block_info[1].strip()
block_content = block_info[2].strip()
# print(f"坐标: {block_bbox}")
# print(f"类型: {block_type}")
# print(f"内容: {block_content}")
# print("-" * 50)
except Exception as e:
# 如果解析失败,可能是因为格式不正确,跳过这个块
logger.warning(f"Invalid block format: {block_info}, error: {e}")
continue
span_type = "unknown"
if block_type in [
......
from mineru.utils.config_reader import get_latex_delimiter_config
import os
from mineru.utils.config_reader import get_latex_delimiter_config, get_formula_enable, get_table_enable
from mineru.utils.enum_class import MakeMode, BlockType, ContentType
......@@ -16,7 +18,7 @@ display_right_delimiter = delimiters['display']['right']
inline_left_delimiter = delimiters['inline']['left']
inline_right_delimiter = delimiters['inline']['right']
def merge_para_with_text(para_block):
def merge_para_with_text(para_block, formula_enable=True, img_buket_path=''):
para_text = ''
for line in para_block['lines']:
for j, span in enumerate(line['spans']):
......@@ -27,7 +29,11 @@ def merge_para_with_text(para_block):
elif span_type == ContentType.INLINE_EQUATION:
content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
elif span_type == ContentType.INTERLINE_EQUATION:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
if formula_enable:
content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
else:
if span.get('image_path', ''):
content = f"![]({img_buket_path}/{span['image_path']})"
# content = content.strip()
if content:
if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
......@@ -39,13 +45,13 @@ def merge_para_with_text(para_block):
para_text += content
return para_text
def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path=''):
page_markdown = []
for para_block in para_blocks:
para_text = ''
para_type = para_block['type']
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
para_text = merge_para_with_text(para_block)
para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
elif para_type == BlockType.TITLE:
title_level = get_title_level(para_block)
para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
......@@ -95,10 +101,14 @@ def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
for span in line['spans']:
if span['type'] == ContentType.TABLE:
# if processed by table model
if span.get('html', ''):
para_text += f"\n{span['html']}\n"
elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
if table_enable:
if span.get('html', ''):
para_text += f"\n{span['html']}\n"
elif span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
else:
if span.get('image_path', ''):
para_text += f"![]({img_buket_path}/{span['image_path']})"
for block in para_block['blocks']: # 3rd.拼table_footnote
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_text += '\n' + merge_para_with_text(block) + ' '
......@@ -120,25 +130,25 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
para_content = {}
if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX]:
para_content = {
'type': 'text',
'type': ContentType.TEXT,
'text': merge_para_with_text(para_block),
}
elif para_type == BlockType.TITLE:
title_level = get_title_level(para_block)
para_content = {
'type': 'text',
'type': ContentType.TEXT,
'text': merge_para_with_text(para_block),
}
if title_level != 0:
para_content['text_level'] = title_level
elif para_type == BlockType.INTERLINE_EQUATION:
para_content = {
'type': 'equation',
'type': ContentType.EQUATION,
'text': merge_para_with_text(para_block),
'text_format': 'latex',
}
elif para_type == BlockType.IMAGE:
para_content = {'type': 'image', 'img_path': '', 'img_caption': [], 'img_footnote': []}
para_content = {'type': ContentType.IMAGE, 'img_path': '', BlockType.IMAGE_CAPTION: [], BlockType.IMAGE_FOOTNOTE: []}
for block in para_block['blocks']:
if block['type'] == BlockType.IMAGE_BODY:
for line in block['lines']:
......@@ -147,11 +157,11 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.IMAGE_CAPTION:
para_content['img_caption'].append(merge_para_with_text(block))
para_content[BlockType.IMAGE_CAPTION].append(merge_para_with_text(block))
if block['type'] == BlockType.IMAGE_FOOTNOTE:
para_content['img_footnote'].append(merge_para_with_text(block))
para_content[BlockType.IMAGE_FOOTNOTE].append(merge_para_with_text(block))
elif para_type == BlockType.TABLE:
para_content = {'type': 'table', 'img_path': '', 'table_caption': [], 'table_footnote': []}
para_content = {'type': ContentType.TABLE, 'img_path': '', BlockType.TABLE_CAPTION: [], BlockType.TABLE_FOOTNOTE: []}
for block in para_block['blocks']:
if block['type'] == BlockType.TABLE_BODY:
for line in block['lines']:
......@@ -159,15 +169,15 @@ def make_blocks_to_content_list(para_block, img_buket_path, page_idx):
if span['type'] == ContentType.TABLE:
if span.get('html', ''):
para_content['table_body'] = f"{span['html']}"
para_content[BlockType.TABLE_BODY] = f"{span['html']}"
if span.get('image_path', ''):
para_content['img_path'] = f"{img_buket_path}/{span['image_path']}"
if block['type'] == BlockType.TABLE_CAPTION:
para_content['table_caption'].append(merge_para_with_text(block))
para_content[BlockType.TABLE_CAPTION].append(merge_para_with_text(block))
if block['type'] == BlockType.TABLE_FOOTNOTE:
para_content['table_footnote'].append(merge_para_with_text(block))
para_content[BlockType.TABLE_FOOTNOTE].append(merge_para_with_text(block))
para_content['page_idx'] = page_idx
......@@ -177,6 +187,10 @@ def union_make(pdf_info_dict: list,
make_mode: str,
img_buket_path: str = '',
):
formula_enable = get_formula_enable(os.getenv('MINERU_VLM_FORMULA_ENABLE', 'True').lower() == 'true')
table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
output_content = []
for page_info in pdf_info_dict:
paras_of_layout = page_info.get('para_blocks')
......@@ -184,7 +198,7 @@ def union_make(pdf_info_dict: list,
if not paras_of_layout:
continue
if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, formula_enable, table_enable, img_buket_path)
output_content.extend(page_markdown)
elif make_mode == MakeMode.CONTENT_LIST:
for para_block in paras_of_layout:
......
......@@ -4,12 +4,14 @@ import click
from pathlib import Path
from loguru import logger
from mineru.utils.cli_parser import arg_parse
from mineru.utils.config_reader import get_device
from mineru.utils.model_utils import get_vram
from ..version import __version__
from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
@click.command()
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.version_option(__version__,
'--version',
'-v',
......@@ -60,7 +62,8 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
'-l',
'--lang',
'lang',
type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']),
type=click.Choice(['ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka',
'latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']),
help="""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
Without languages specified, 'ch' will be used by default.
......@@ -136,7 +139,14 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
)
def main(input_path, output_dir, method, backend, lang, server_url, start_page_id, end_page_id, formula_enable, table_enable, device_mode, virtual_vram, model_source):
def main(
ctx,
input_path, output_dir, method, backend, lang, server_url,
start_page_id, end_page_id, formula_enable, table_enable,
device_mode, virtual_vram, model_source, **kwargs
):
kwargs.update(arg_parse(ctx))
if not backend.endswith('-client'):
def get_device_mode() -> str:
......@@ -179,11 +189,12 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
p_lang_list=lang_list,
backend=backend,
parse_method=method,
p_formula_enable=formula_enable,
p_table_enable=table_enable,
formula_enable=formula_enable,
table_enable=table_enable,
server_url=server_url,
start_page_id=start_page_id,
end_page_id=end_page_id
end_page_id=end_page_id,
**kwargs,
)
except Exception as e:
logger.exception(e)
......
......@@ -14,9 +14,10 @@ from mineru.utils.enum_class import MakeMode
from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
from mineru.backend.vlm.vlm_middle_json_mkcontent import union_make as vlm_union_make
from mineru.backend.vlm.vlm_analyze import doc_analyze as vlm_doc_analyze
from mineru.backend.vlm.vlm_analyze import aio_doc_analyze as aio_vlm_doc_analyze
pdf_suffixes = [".pdf"]
image_suffixes = [".png", ".jpeg", ".jpg"]
image_suffixes = [".png", ".jpeg", ".jpg", ".webp", ".gif"]
def read_fn(path):
......@@ -73,155 +74,318 @@ def convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id=0, end_page
return output_bytes
def do_parse(
output_dir,
pdf_file_names: list[str],
pdf_bytes_list: list[bytes],
p_lang_list: list[str],
backend="pipeline",
parse_method="auto",
p_formula_enable=True,
p_table_enable=True,
server_url=None,
f_draw_layout_bbox=True,
f_draw_span_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_output=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
def _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id):
"""准备处理PDF字节数据"""
result = []
for pdf_bytes in pdf_bytes_list:
new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
result.append(new_pdf_bytes)
return result
def _process_output(
pdf_info,
pdf_bytes,
pdf_file_name,
local_md_dir,
local_image_dir,
md_writer,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_orig_pdf,
f_dump_md,
f_dump_content_list,
f_dump_middle_json,
f_dump_model_output,
f_make_md_mode,
middle_json,
model_output=None,
is_pipeline=True
):
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
"""处理输出文件"""
if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
if f_dump_orig_pdf:
md_writer.write(
f"{pdf_file_name}_origin.pdf",
pdf_bytes,
)
image_dir = str(os.path.basename(local_image_dir))
if f_dump_md:
make_func = pipeline_union_make if is_pipeline else vlm_union_make
md_content_str = make_func(pdf_info, f_make_md_mode, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
if f_dump_content_list:
make_func = pipeline_union_make if is_pipeline else vlm_union_make
content_list = make_func(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
if f_dump_middle_json:
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
if f_dump_model_output:
if is_pipeline:
md_writer.write_string(
f"{pdf_file_name}_model.json",
json.dumps(model_output, ensure_ascii=False, indent=4),
)
else:
output_text = ("\n" + "-" * 50 + "\n").join(model_output)
md_writer.write_string(
f"{pdf_file_name}_model_output.txt",
output_text,
)
logger.info(f"local output dir is {local_md_dir}")
def _process_pipeline(
output_dir,
pdf_file_names,
pdf_bytes_list,
p_lang_list,
parse_method,
p_formula_enable,
p_table_enable,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
):
"""处理pipeline后端逻辑"""
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = (
pipeline_doc_analyze(
pdf_bytes_list, p_lang_list, parse_method=parse_method,
formula_enable=p_formula_enable, table_enable=p_table_enable
)
)
for idx, model_list in enumerate(infer_results):
model_json = copy.deepcopy(model_list)
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
images_list = all_image_lists[idx]
pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx]
middle_json = pipeline_result_to_middle_json(
model_list, images_list, pdf_doc, image_writer,
_lang, _ocr_enable, p_formula_enable
)
pdf_info = middle_json["pdf_info"]
pdf_bytes = pdf_bytes_list[idx]
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, model_json, is_pipeline=True
)
async def _async_process_vlm(
output_dir,
pdf_file_names,
pdf_bytes_list,
backend,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""异步处理VLM后端逻辑"""
parse_method = "vlm"
f_draw_span_bbox = False
if not backend.endswith("client"):
server_url = None
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = await aio_vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
)
pdf_info = middle_json["pdf_info"]
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, infer_result, is_pipeline=False
)
def _process_vlm(
output_dir,
pdf_file_names,
pdf_bytes_list,
backend,
f_draw_layout_bbox,
f_draw_span_bbox,
f_dump_md,
f_dump_middle_json,
f_dump_model_output,
f_dump_orig_pdf,
f_dump_content_list,
f_make_md_mode,
server_url=None,
**kwargs,
):
"""同步处理VLM后端逻辑"""
parse_method = "vlm"
f_draw_span_bbox = False
if not backend.endswith("client"):
server_url = None
if backend == "pipeline":
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import union_make as pipeline_union_make
from mineru.backend.pipeline.model_json_to_middle_json import result_to_middle_json as pipeline_result_to_middle_json
from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc_analyze
for idx, pdf_bytes in enumerate(pdf_bytes_list):
new_pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
pdf_bytes_list[idx] = new_pdf_bytes
infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list = pipeline_doc_analyze(pdf_bytes_list, p_lang_list, parse_method=parse_method, formula_enable=p_formula_enable,table_enable=p_table_enable)
for idx, model_list in enumerate(infer_results):
model_json = copy.deepcopy(model_list)
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
images_list = all_image_lists[idx]
pdf_doc = all_pdf_docs[idx]
_lang = lang_list[idx]
_ocr_enable = ocr_enabled_list[idx]
middle_json = pipeline_result_to_middle_json(model_list, images_list, pdf_doc, image_writer, _lang, _ocr_enable, p_formula_enable)
pdf_info = middle_json["pdf_info"]
pdf_bytes = pdf_bytes_list[idx]
if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
if f_dump_orig_pdf:
md_writer.write(
f"{pdf_file_name}_origin.pdf",
pdf_bytes,
)
middle_json, infer_result = vlm_doc_analyze(
pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url, **kwargs,
)
if f_dump_md:
image_dir = str(os.path.basename(local_image_dir))
md_content_str = pipeline_union_make(pdf_info, f_make_md_mode, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
pdf_info = middle_json["pdf_info"]
if f_dump_content_list:
image_dir = str(os.path.basename(local_image_dir))
content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
_process_output(
pdf_info, pdf_bytes, pdf_file_name, local_md_dir, local_image_dir,
md_writer, f_draw_layout_bbox, f_draw_span_bbox, f_dump_orig_pdf,
f_dump_md, f_dump_content_list, f_dump_middle_json, f_dump_model_output,
f_make_md_mode, middle_json, infer_result, is_pipeline=False
)
if f_dump_middle_json:
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
if f_dump_model_output:
md_writer.write_string(
f"{pdf_file_name}_model.json",
json.dumps(model_json, ensure_ascii=False, indent=4),
)
def do_parse(
output_dir,
pdf_file_names: list[str],
pdf_bytes_list: list[bytes],
p_lang_list: list[str],
backend="pipeline",
parse_method="auto",
formula_enable=True,
table_enable=True,
server_url=None,
f_draw_layout_bbox=True,
f_draw_span_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_output=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
**kwargs,
):
# 预处理PDF字节数据
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
logger.info(f"local output dir is {local_md_dir}")
if backend == "pipeline":
_process_pipeline(
output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
)
else:
if backend.startswith("vlm-"):
backend = backend[4:]
f_draw_span_bbox = False
parse_method = "vlm"
for idx, pdf_bytes in enumerate(pdf_bytes_list):
pdf_file_name = pdf_file_names[idx]
pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
pdf_info = middle_json["pdf_info"]
if f_draw_layout_bbox:
draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_layout.pdf")
if f_draw_span_bbox:
draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, f"{pdf_file_name}_span.pdf")
if f_dump_orig_pdf:
md_writer.write(
f"{pdf_file_name}_origin.pdf",
pdf_bytes,
)
if f_dump_md:
image_dir = str(os.path.basename(local_image_dir))
md_content_str = vlm_union_make(pdf_info, f_make_md_mode, image_dir)
md_writer.write_string(
f"{pdf_file_name}.md",
md_content_str,
)
if f_dump_content_list:
image_dir = str(os.path.basename(local_image_dir))
content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
f"{pdf_file_name}_content_list.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
_process_vlm(
output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url, **kwargs,
)
async def aio_do_parse(
output_dir,
pdf_file_names: list[str],
pdf_bytes_list: list[bytes],
p_lang_list: list[str],
backend="pipeline",
parse_method="auto",
formula_enable=True,
table_enable=True,
server_url=None,
f_draw_layout_bbox=True,
f_draw_span_bbox=True,
f_dump_md=True,
f_dump_middle_json=True,
f_dump_model_output=True,
f_dump_orig_pdf=True,
f_dump_content_list=True,
f_make_md_mode=MakeMode.MM_MD,
start_page_id=0,
end_page_id=None,
**kwargs,
):
# 预处理PDF字节数据
pdf_bytes_list = _prepare_pdf_bytes(pdf_bytes_list, start_page_id, end_page_id)
if f_dump_middle_json:
md_writer.write_string(
f"{pdf_file_name}_middle.json",
json.dumps(middle_json, ensure_ascii=False, indent=4),
)
if backend == "pipeline":
# pipeline模式暂不支持异步,使用同步处理方式
_process_pipeline(
output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
parse_method, formula_enable, table_enable,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
)
else:
if backend.startswith("vlm-"):
backend = backend[4:]
if f_dump_model_output:
model_output = ("\n" + "-" * 50 + "\n").join(infer_result)
md_writer.write_string(
f"{pdf_file_name}_model_output.txt",
model_output,
)
os.environ['MINERU_VLM_FORMULA_ENABLE'] = str(formula_enable)
os.environ['MINERU_VLM_TABLE_ENABLE'] = str(table_enable)
logger.info(f"local output dir is {local_md_dir}")
await _async_process_vlm(
output_dir, pdf_file_names, pdf_bytes_list, backend,
f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode,
server_url, **kwargs,
)
......
import uuid
import os
import uvicorn
import click
from pathlib import Path
from glob import glob
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse
from typing import List, Optional
from loguru import logger
from base64 import b64encode
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
from mineru.utils.cli_parser import arg_parse
from mineru.version import __version__
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)
def encode_image(image_path: str) -> str:
"""Encode image using base64"""
with open(image_path, "rb") as f:
return b64encode(f.read()).decode()
def get_infer_result(file_suffix_identifier: str, pdf_name: str, parse_dir: str) -> Optional[str]:
"""从结果文件中读取推理结果"""
result_file_path = os.path.join(parse_dir, f"{pdf_name}{file_suffix_identifier}")
if os.path.exists(result_file_path):
with open(result_file_path, "r", encoding="utf-8") as fp:
return fp.read()
return None
@app.post(path="/file_parse",)
async def parse_pdf(
files: List[UploadFile] = File(...),
output_dir: str = Form("./output"),
lang_list: List[str] = Form(["ch"]),
backend: str = Form("pipeline"),
parse_method: str = Form("auto"),
formula_enable: bool = Form(True),
table_enable: bool = Form(True),
server_url: Optional[str] = Form(None),
return_md: bool = Form(True),
return_middle_json: bool = Form(False),
return_model_output: bool = Form(False),
return_content_list: bool = Form(False),
return_images: bool = Form(False),
start_page_id: int = Form(0),
end_page_id: int = Form(99999),
):
# 获取命令行配置参数
config = getattr(app.state, "config", {})
try:
# 创建唯一的输出目录
unique_dir = os.path.join(output_dir, str(uuid.uuid4()))
os.makedirs(unique_dir, exist_ok=True)
# 处理上传的PDF文件
pdf_file_names = []
pdf_bytes_list = []
for file in files:
content = await file.read()
file_path = Path(file.filename)
# 如果是图像文件或PDF,使用read_fn处理
if file_path.suffix.lower() in pdf_suffixes + image_suffixes:
# 创建临时文件以便使用read_fn
temp_path = Path(unique_dir) / file_path.name
with open(temp_path, "wb") as f:
f.write(content)
try:
pdf_bytes = read_fn(temp_path)
pdf_bytes_list.append(pdf_bytes)
pdf_file_names.append(file_path.stem)
os.remove(temp_path) # 删除临时文件
except Exception as e:
return JSONResponse(
status_code=400,
content={"error": f"Failed to load file: {str(e)}"}
)
else:
return JSONResponse(
status_code=400,
content={"error": f"Unsupported file type: {file_path.suffix}"}
)
# 设置语言列表,确保与文件数量一致
actual_lang_list = lang_list
if len(actual_lang_list) != len(pdf_file_names):
# 如果语言列表长度不匹配,使用第一个语言或默认"ch"
actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)
# 调用异步处理函数
await aio_do_parse(
output_dir=unique_dir,
pdf_file_names=pdf_file_names,
pdf_bytes_list=pdf_bytes_list,
p_lang_list=actual_lang_list,
backend=backend,
parse_method=parse_method,
formula_enable=formula_enable,
table_enable=table_enable,
server_url=server_url,
f_draw_layout_bbox=False,
f_draw_span_bbox=False,
f_dump_md=return_md,
f_dump_middle_json=return_middle_json,
f_dump_model_output=return_model_output,
f_dump_orig_pdf=False,
f_dump_content_list=return_content_list,
start_page_id=start_page_id,
end_page_id=end_page_id,
**config
)
# 构建结果路径
result_dict = {}
for pdf_name in pdf_file_names:
result_dict[pdf_name] = {}
data = result_dict[pdf_name]
if backend.startswith("pipeline"):
parse_dir = os.path.join(unique_dir, pdf_name, parse_method)
else:
parse_dir = os.path.join(unique_dir, pdf_name, "vlm")
if os.path.exists(parse_dir):
if return_md:
data["md_content"] = get_infer_result(".md", pdf_name, parse_dir)
if return_middle_json:
data["middle_json"] = get_infer_result("_middle.json", pdf_name, parse_dir)
if return_model_output:
if backend.startswith("pipeline"):
data["model_output"] = get_infer_result("_model.json", pdf_name, parse_dir)
else:
data["model_output"] = get_infer_result("_model_output.txt", pdf_name, parse_dir)
if return_content_list:
data["content_list"] = get_infer_result("_content_list.json", pdf_name, parse_dir)
if return_images:
image_paths = glob(f"{parse_dir}/images/*.jpg")
data["images"] = {
os.path.basename(
image_path
): f"data:image/jpeg;base64,{encode_image(image_path)}"
for image_path in image_paths
}
return JSONResponse(
status_code=200,
content={
"backend": backend,
"version": __version__,
"results": result_dict
}
)
except Exception as e:
logger.exception(e)
return JSONResponse(
status_code=500,
content={"error": f"Failed to process file: {str(e)}"}
)
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
def main(ctx, host, port, reload, **kwargs):
kwargs.update(arg_parse(ctx))
# 将配置参数存储到应用状态中
app.state.config = kwargs
"""启动MinerU FastAPI服务器的命令行入口"""
print(f"Start MinerU FastAPI Service: http://{host}:{port}")
print("The API documentation can be accessed at the following address:")
print(f"- Swagger UI: http://{host}:{port}/docs")
print(f"- ReDoc: http://{host}:{port}/redoc")
uvicorn.run(
"mineru.cli.fast_api:app",
host=host,
port=port,
reload=reload
)
if __name__ == "__main__":
main()
\ No newline at end of file
......@@ -7,38 +7,47 @@ import time
import zipfile
from pathlib import Path
import click
import gradio as gr
from gradio_pdf import PDF
from loguru import logger
from mineru.cli.common import prepare_env, do_parse, read_fn
from mineru.cli.common import prepare_env, read_fn, aio_do_parse, pdf_suffixes, image_suffixes
from mineru.utils.cli_parser import arg_parse
from mineru.utils.hash_utils import str_sha256
def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language):
async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, table_enable, language, backend, url):
os.makedirs(output_dir, exist_ok=True)
try:
file_name = f'{str(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
file_name = f'{safe_stem(Path(doc_path).stem)}_{time.strftime("%y%m%d_%H%M%S")}'
pdf_data = read_fn(doc_path)
if is_ocr:
parse_method = 'ocr'
else:
parse_method = 'auto'
if backend.startswith("vlm"):
parse_method = "vlm"
local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
do_parse(
await aio_do_parse(
output_dir=output_dir,
pdf_file_names=[file_name],
pdf_bytes_list=[pdf_data],
p_lang_list=[language],
parse_method=parse_method,
end_page_id=end_page_id,
p_formula_enable=formula_enable,
p_table_enable=table_enable,
formula_enable=formula_enable,
table_enable=table_enable,
backend=backend,
server_url=url,
)
return local_md_dir, file_name
except Exception as e:
logger.exception(e)
return None
def compress_directory_to_zip(directory_path, output_zip_path):
......@@ -85,16 +94,16 @@ def replace_image_with_base64(markdown_text, image_dir_path):
return re.sub(pattern, replace, markdown_text)
def to_markdown(file_path, end_pages, is_ocr, formula_enable, table_enable, language):
async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True, table_enable=True, language="ch", backend="pipeline", url=None):
file_path = to_pdf(file_path)
# 获取识别的md文件以及压缩包文件路径
local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language)
local_md_dir, file_name = await parse_pdf(file_path, './output', end_pages - 1, is_ocr, formula_enable, table_enable, language, backend, url)
archive_zip_path = os.path.join('./output', str_sha256(local_md_dir) + '.zip')
zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
if zip_archive_success == 0:
logger.info('压缩成功')
logger.info('Compression successful')
else:
logger.error('压缩失败')
logger.error('Compression failed')
md_path = os.path.join(local_md_dir, file_name + '.md')
with open(md_path, 'r', encoding='utf-8') as f:
txt_content = f.read()
......@@ -112,9 +121,9 @@ latex_delimiters = [
{'left': '\\[', 'right': '\\]', 'display': True},
]
with open('header.html', 'r') as file:
header = file.read()
header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html')
with open(header_path, 'r') as header_file:
header = header_file.read()
latin_lang = [
......@@ -125,15 +134,16 @@ latin_lang = [
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
east_slavic_lang = ["ru", "be", "uk"]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
'sa', 'bgc'
]
other_lang = ['ch', 'ch_lite', 'ch_server', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
add_lang = ['latin', 'arabic', 'cyrillic', 'devanagari']
add_lang = ['latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari']
# all_lang = ['', 'auto']
all_lang = []
......@@ -167,33 +177,125 @@ def to_pdf(file_path):
return tmp_file_path
if __name__ == '__main__':
# 更新界面函数
def update_interface(backend_choice):
if backend_choice in ["vlm-transformers", "vlm-sglang-engine"]:
return gr.update(visible=False), gr.update(visible=False)
elif backend_choice in ["vlm-sglang-client"]:
return gr.update(visible=True), gr.update(visible=False)
elif backend_choice in ["pipeline"]:
return gr.update(visible=False), gr.update(visible=True)
else:
pass
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option(
'--enable-example',
'example_enable',
type=bool,
help="Enable example files for input."
"The example files to be input need to be placed in the `example` folder within the directory where the command is currently executed.",
default=True,
)
@click.option(
'--enable-sglang-engine',
'sglang_engine_enable',
type=bool,
help="Enable SgLang engine backend for faster processing.",
default=False,
)
@click.option(
'--enable-api',
'api_enable',
type=bool,
help="Enable gradio API for serving the application.",
default=True,
)
@click.option(
'--max-convert-pages',
'max_convert_pages',
type=int,
help="Set the maximum number of pages to convert from PDF to Markdown.",
default=1000,
)
@click.option(
'--server-name',
'server_name',
type=str,
help="Set the server name for the Gradio app.",
default=None,
)
@click.option(
'--server-port',
'server_port',
type=int,
help="Set the server port for the Gradio app.",
default=None,
)
def main(ctx,
example_enable, sglang_engine_enable, api_enable, max_convert_pages,
server_name, server_port, **kwargs
):
kwargs.update(arg_parse(ctx))
if sglang_engine_enable:
try:
print("Start init SgLang engine...")
from mineru.backend.vlm.vlm_analyze import ModelSingleton
model_singleton = ModelSingleton()
predictor = model_singleton.get_model(
"sglang-engine",
None,
None,
**kwargs
)
print("SgLang engine init successfully.")
except Exception as e:
logger.exception(e)
suffixes = pdf_suffixes + image_suffixes
with gr.Blocks() as demo:
gr.HTML(header)
with gr.Row():
with gr.Column(variant='panel', scale=5):
with gr.Row():
file = gr.File(label='Please upload a PDF or image', file_types=['.pdf', '.png', '.jpeg', '.jpg'])
input_file = gr.File(label='Please upload a PDF or image', file_types=suffixes)
with gr.Row():
max_pages = gr.Slider(1, max_convert_pages, int(max_convert_pages/2), step=1, label='Max convert pages')
with gr.Row():
if sglang_engine_enable:
drop_list = ["pipeline", "vlm-sglang-engine"]
preferred_option = "vlm-sglang-engine"
else:
drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
preferred_option = "pipeline"
backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
with gr.Row(visible=False) as client_options:
url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
with gr.Row(equal_height=True):
with gr.Column(scale=4):
max_pages = gr.Slider(1, 20, 10, step=1, label='Max convert pages')
with gr.Column(scale=1):
with gr.Column():
gr.Markdown("**Recognition Options:**")
formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
table_enable = gr.Checkbox(label='Enable table recognition', value=True)
with gr.Column(visible=False) as ocr_options:
language = gr.Dropdown(all_lang, label='Language', value='ch')
with gr.Row():
is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
table_enable = gr.Checkbox(label='Enable table recognition(test)', value=True)
is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
with gr.Row():
change_bu = gr.Button('Convert')
clear_bu = gr.ClearButton(value='Clear')
pdf_show = PDF(label='PDF preview', interactive=False, visible=True, height=800)
with gr.Accordion('Examples:'):
example_root = os.path.join(os.path.dirname(__file__), 'examples')
gr.Examples(
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
_.endswith('pdf')],
inputs=file
)
if example_enable:
example_root = os.path.join(os.getcwd(), 'examples')
if os.path.exists(example_root):
with gr.Accordion('Examples:'):
gr.Examples(
examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
_.endswith(tuple(suffixes))],
inputs=input_file
)
with gr.Column(variant='panel', scale=5):
output_file = gr.File(label='convert result', interactive=False)
......@@ -204,9 +306,38 @@ if __name__ == '__main__':
line_breaks=True)
with gr.Tab('Markdown text'):
md_text = gr.TextArea(lines=45, show_copy_button=True)
file.change(fn=to_pdf, inputs=file, outputs=pdf_show)
change_bu.click(fn=to_markdown, inputs=[file, max_pages, is_ocr, formula_enable, table_enable, language],
outputs=[md, md_text, output_file, pdf_show])
clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr])
demo.launch(server_name='0.0.0.0')
# 添加事件处理
backend.change(
fn=update_interface,
inputs=[backend],
outputs=[client_options, ocr_options],
api_name=False
)
# 添加demo.load事件,在页面加载时触发一次界面更新
demo.load(
fn=update_interface,
inputs=[backend],
outputs=[client_options, ocr_options],
api_name=False
)
clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])
if api_enable:
api_name = None
else:
api_name = False
input_file.change(fn=to_pdf, inputs=input_file, outputs=pdf_show, api_name=api_name)
change_bu.click(
fn=to_markdown,
inputs=[input_file, max_pages, is_ocr, formula_enable, table_enable, language, backend, url],
outputs=[md, md_text, output_file, pdf_show],
api_name=api_name
)
demo.launch(server_name=server_name, server_port=server_port, show_api=api_enable)
if __name__ == '__main__':
main()
\ No newline at end of file
......@@ -3,6 +3,7 @@ import os
import sys
import click
import requests
from loguru import logger
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
......@@ -54,7 +55,32 @@ def configure_model(model_dir, model_type):
}
download_and_modify_json(json_url, config_file, json_mods)
print(f'The configuration file has been successfully configured, the path is: {config_file}')
logger.info(f'The configuration file has been successfully configured, the path is: {config_file}')
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus
]
download_finish_path = ""
for model_path in model_paths:
logger.info(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
logger.info(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "pipeline")
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
logger.info(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "vlm")
@click.command()
......@@ -102,30 +128,7 @@ def download_models(model_source, model_type):
default='all'
)
click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus
]
download_finish_path = ""
for model_path in model_paths:
click.echo(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "pipeline")
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "vlm")
logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
try:
if model_type == 'pipeline':
......@@ -140,7 +143,7 @@ def download_models(model_source, model_type):
sys.exit(1)
except Exception as e:
click.echo(f"Download failed: {str(e)}", err=True)
logger.exception(f"An error occurred while downloading models: {str(e)}")
sys.exit(1)
if __name__ == '__main__':
......
from typing import List, Dict, Union
from doclayout_yolo import YOLOv10
from tqdm import tqdm
import numpy as np
from PIL import Image
class DocLayoutYOLOModel(object):
def __init__(self, weight, device):
self.model = YOLOv10(weight)
class DocLayoutYOLOModel:
def __init__(
self,
weight: str,
device: str = "cuda",
imgsz: int = 1280,
conf: float = 0.1,
iou: float = 0.45,
):
self.model = YOLOv10(weight).to(device)
self.device = device
self.imgsz = imgsz
self.conf = conf
self.iou = iou
def predict(self, image):
def _parse_prediction(self, prediction) -> List[Dict]:
layout_res = []
doclayout_yolo_res = self.model.predict(
image,
imgsz=1280,
conf=0.10,
iou=0.45,
verbose=False, device=self.device
)[0]
for xyxy, conf, cla in zip(
doclayout_yolo_res.boxes.xyxy.cpu(),
doclayout_yolo_res.boxes.conf.cpu(),
doclayout_yolo_res.boxes.cls.cpu(),
# 容错处理
if not hasattr(prediction, "boxes") or prediction.boxes is None:
return layout_res
for xyxy, conf, cls in zip(
prediction.boxes.xyxy.cpu(),
prediction.boxes.conf.cpu(),
prediction.boxes.cls.cpu(),
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
"category_id": int(cla.item()),
coords = list(map(int, xyxy.tolist()))
xmin, ymin, xmax, ymax = coords
layout_res.append({
"category_id": int(cls.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 3),
}
layout_res.append(new_item)
})
return layout_res
def batch_predict(self, images: list, batch_size: int) -> list:
images_layout_res = []
# for index in range(0, len(images), batch_size):
for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
doclayout_yolo_res = [
image_res.cpu()
for image_res in self.model.predict(
images[index : index + batch_size],
imgsz=1280,
conf=0.10,
iou=0.45,
def predict(self, image: Union[np.ndarray, Image.Image]) -> List[Dict]:
prediction = self.model.predict(
image,
imgsz=self.imgsz,
conf=self.conf,
iou=self.iou,
verbose=False
)[0]
return self._parse_prediction(prediction)
def batch_predict(
self,
images: List[Union[np.ndarray, Image.Image]],
batch_size: int = 4
) -> List[List[Dict]]:
results = []
with tqdm(total=len(images), desc="Layout Predict") as pbar:
for idx in range(0, len(images), batch_size):
batch = images[idx: idx + batch_size]
predictions = self.model.predict(
batch,
imgsz=self.imgsz,
conf=self.conf,
iou=self.iou,
verbose=False,
device=self.device,
)
]
for image_res in doclayout_yolo_res:
layout_res = []
for xyxy, conf, cla in zip(
image_res.boxes.xyxy,
image_res.boxes.conf,
image_res.boxes.cls,
):
xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
new_item = {
"category_id": int(cla.item()),
"poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
"score": round(float(conf.item()), 3),
}
layout_res.append(new_item)
images_layout_res.append(layout_res)
return images_layout_res
for pred in predictions:
results.append(self._parse_prediction(pred))
pbar.update(len(batch))
return results
\ No newline at end of file
from typing import List, Union
from tqdm import tqdm
from ultralytics import YOLO
import numpy as np
from PIL import Image
class YOLOv8MFDModel(object):
def __init__(self, weight, device="cpu"):
self.mfd_model = YOLO(weight)
class YOLOv8MFDModel:
def __init__(
self,
weight: str,
device: str = "cpu",
imgsz: int = 1888,
conf: float = 0.25,
iou: float = 0.45,
):
self.model = YOLO(weight).to(device)
self.device = device
self.imgsz = imgsz
self.conf = conf
self.iou = iou
def predict(self, image):
mfd_res = self.mfd_model.predict(
image, imgsz=1888, conf=0.25, iou=0.45, verbose=False, device=self.device
)[0]
return mfd_res
def _run_predict(
self,
inputs: Union[np.ndarray, Image.Image, List],
is_batch: bool = False
) -> List:
preds = self.model.predict(
inputs,
imgsz=self.imgsz,
conf=self.conf,
iou=self.iou,
verbose=False,
device=self.device
)
return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()
def batch_predict(self, images: list, batch_size: int) -> list:
images_mfd_res = []
# for index in range(0, len(images), batch_size):
for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
mfd_res = [
image_res.cpu()
for image_res in self.mfd_model.predict(
images[index : index + batch_size],
imgsz=1888,
conf=0.25,
iou=0.45,
verbose=False,
device=self.device,
)
]
for image_res in mfd_res:
images_mfd_res.append(image_res)
return images_mfd_res
def predict(self, image: Union[np.ndarray, Image.Image]):
return self._run_predict(image)
def batch_predict(
self,
images: List[Union[np.ndarray, Image.Image]],
batch_size: int = 4
) -> List:
results = []
with tqdm(total=len(images), desc="MFD Predict") as pbar:
for idx in range(0, len(images), batch_size):
batch = images[idx: idx + batch_size]
batch_preds = self._run_predict(batch, is_batch=True)
results.extend(batch_preds)
pbar.update(len(batch))
return results
\ No newline at end of file
......@@ -26,9 +26,10 @@ latin_lang = [
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
'rs_cyrillic', 'bg', 'mn', 'abq', 'ady', 'kbd', 'ava', # noqa: E126
'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
east_slavic_lang = ["ru", "be", "uk"]
devanagari_lang = [
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom', # noqa: E126
'sa', 'bgc'
......@@ -58,7 +59,7 @@ class PytorchPaddleOCR(TextSystem):
device = get_device()
if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
logger.warning("The current device in use is CPU. To ensure the speed of parsing, the language is automatically switched to ch_lite.")
# logger.warning("The current device in use is CPU. To ensure the speed of parsing, the language is automatically switched to ch_lite.")
self.lang = 'ch_lite'
if self.lang in latin_lang:
......@@ -69,6 +70,8 @@ class PytorchPaddleOCR(TextSystem):
self.lang = 'cyrillic'
elif self.lang in devanagari_lang:
self.lang = 'devanagari'
elif self.lang in east_slavic_lang:
self.lang = 'east_slavic'
else:
pass
......
......@@ -490,3 +490,82 @@ devanagari_PP-OCRv3_rec_infer:
# out_channels: 169
fc_decay: 0.00001
korean_PP-OCRv5_rec_infer:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
out_channels_list:
CTCLabelDecode: 11947
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [ 1, 3 ]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: 25
latin_PP-OCRv5_rec_infer:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
out_channels_list:
CTCLabelDecode: 504
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [ 1, 3 ]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: 25
eslav_PP-OCRv5_rec_infer:
model_type: rec
algorithm: SVTR_HGNet
Transform:
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
out_channels_list:
CTCLabelDecode: 519
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [ 1, 3 ]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: 25
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
]
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
©
{
}
\
|
@
^
~
÷
·
±
®
Ω
¢
£
¥
𝑢
𝜓
ƒ
À
Á
Â
Ã
Ä
Å
Æ
Ç
È
É
Ê
Ë
Ì
Í
Î
Ï
Ð
Ñ
Ò
Ó
Ô
Õ
Ö
Ø
Ù
Ú
Û
Ü
Ý
Þ
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
ø
ù
ú
û
ü
ý
þ
ÿ
¡
¤
¦
§
¨
ª
«
¬
¯
°
²
³
´
µ
¸
¹
º
»
¼
½
¾
¿
×
Α
α
Β
β
Γ
γ
Δ
δ
Ε
ε
Ζ
ζ
Η
η
Θ
θ
Ι
ι
Κ
κ
Λ
λ
Μ
μ
Ν
ν
Ξ
ξ
Ο
ο
Π
π
Ρ
ρ
Σ
σ
ς
Τ
τ
Υ
υ
Φ
φ
Χ
χ
Ψ
ψ
ω
А
Б
В
Г
Ґ
Д
Е
Ё
Є
Ж
З
И
І
Ї
Й
К
Л
М
Н
О
П
Р
С
Т
У
Ў
Ф
Х
Ц
Ч
Ш
Щ
Ъ
Ы
Ь
Э
Ю
Я
а
б
в
г
ґ
д
е
ё
є
ж
з
и
і
ї
й
к
л
м
н
о
п
р
с
т
у
ў
ф
х
ц
ч
ш
щ
ъ
ы
ь
э
ю
я
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
¡
¢
£
¤
¥
¦
§
¨
©
ª
«
¬
­
®
¯
°
±
²
³
´
µ
·
¸
¹
º
»
¼
½
¾
¿
À
Á
Â
Ã
Ä
Å
Æ
Ç
È
É
Ê
Ë
Ì
Í
Î
Ï
Ð
Ñ
Ò
Ó
Ô
Õ
Ö
×
Ø
Ù
Ú
Û
Ü
Ý
Þ
ß
à
á
â
ã
ä
å
æ
ç
è
é
ê
ë
ì
í
î
ï
ð
ñ
ò
ó
ô
õ
ö
÷
ø
ù
ú
û
ü
ý
þ
ÿ
Ą
ą
Ć
ć
Č
č
Ď
ď
Đ
đ
Ė
ė
Ę
ę
Ě
ě
Ğ
ğ
Į
į
İ
ı
Ĺ
ĺ
Ľ
ľ
Ł
ł
Ń
ń
Ň
ň
ō
Ő
ő
Œ
œ
Ŕ
ŕ
Ř
ř
Ś
ś
Ş
ş
Š
š
Ť
ť
Ū
ū
Ů
ů
Ű
ű
Ų
ų
Ÿ
Ź
ź
Ż
ż
Ž
ž
ƒ
ʒ
Ω
α
β
γ
δ
ε
ζ
η
θ
ι
κ
λ
μ
ν
ξ
ο
π
ρ
ς
σ
τ
υ
φ
χ
ψ
ω
з
𝑢
𝜓
......@@ -24,17 +24,17 @@ lang:
rec: en_PP-OCRv4_rec_infer.pth
dict: en_dict.txt
korean:
det: Multilingual_PP-OCRv3_det_infer.pth
rec: korean_PP-OCRv3_rec_infer.pth
dict: korean_dict.txt
det: ch_PP-OCRv5_det_infer.pth
rec: korean_PP-OCRv5_rec_infer.pth
dict: ppocrv5_korean_dict.txt
japan:
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv5_rec_server_infer.pth
dict: japan_dict.txt
dict: ppocrv5_dict.txt
chinese_cht:
det: ch_PP-OCRv5_det_infer.pth
rec: ch_PP-OCRv5_rec_server_infer.pth
dict: chinese_cht_dict.txt
dict: ppocrv5_dict.txt
ta:
det: Multilingual_PP-OCRv3_det_infer.pth
rec: ta_PP-OCRv3_rec_infer.pth
......@@ -48,9 +48,9 @@ lang:
rec: ka_PP-OCRv3_rec_infer.pth
dict: ka_dict.txt
latin:
det: en_PP-OCRv3_det_infer.pth
rec: latin_PP-OCRv3_rec_infer.pth
dict: latin_dict.txt
det: ch_PP-OCRv5_det_infer.pth
rec: latin_PP-OCRv5_rec_infer.pth
dict: ppocrv5_latin_dict.txt
arabic:
det: Multilingual_PP-OCRv3_det_infer.pth
rec: arabic_PP-OCRv3_rec_infer.pth
......@@ -62,4 +62,8 @@ lang:
devanagari:
det: Multilingual_PP-OCRv3_det_infer.pth
rec: devanagari_PP-OCRv3_rec_infer.pth
dict: devanagari_dict.txt
\ No newline at end of file
dict: devanagari_dict.txt
east_slavic:
det: ch_PP-OCRv5_det_infer.pth
rec: eslav_PP-OCRv5_rec_infer.pth
dict: ppocrv5_eslav_dict.txt
\ No newline at end of file
......@@ -62,7 +62,7 @@ class Mineru2QwenForCausalLM(nn.Module):
# load vision tower
mm_vision_tower = self.config.mm_vision_tower
model_root_path = auto_download_and_get_model_root_path("/", "vlm")
model_root_path = auto_download_and_get_model_root_path(mm_vision_tower, "vlm")
mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"
if "clip" in mm_vision_tower:
......
......@@ -54,7 +54,7 @@
font-family: 'Trebuchet MS', 'Lucida Sans Unicode',
'Lucida Grande', 'Lucida Sans', Arial, sans-serif;
">
MinerU: PDF Extraction Demo
MinerU 2: PDF Extraction Demo
</h1>
</div>
</div>
......@@ -66,8 +66,7 @@
color: #fafafa;
opacity: 0.8;
">
A one-stop, open-source, high-quality data extraction tool, supports
PDF/webpage/e-book extraction.<br>
A one-stop, open-source, high-quality data extraction tool that supports converting PDF to Markdown and JSON.<br>
</p>
<style>
.link-block {
......
......@@ -90,8 +90,8 @@ def prepare_block_bboxes(
"""经过以上处理后,还存在大框套小框的情况,则删除小框"""
all_bboxes = remove_overlaps_min_blocks(all_bboxes)
all_discarded_blocks = remove_overlaps_min_blocks(all_discarded_blocks)
"""将剩余的bbox做分离处理,防止后面分layout时出错"""
# all_bboxes, drop_reasons = remove_overlap_between_bbox_for_block(all_bboxes)
"""粗排序后返回"""
all_bboxes.sort(key=lambda x: x[0]+x[1])
return all_bboxes, all_discarded_blocks, footnote_blocks
......@@ -213,35 +213,39 @@ def remove_overlaps_min_blocks(all_bboxes):
# 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
# 删除重叠blocks中较小的那些
need_remove = []
for block1 in all_bboxes:
for block2 in all_bboxes:
if block1 != block2:
block1_bbox = block1[:4]
block2_bbox = block2[:4]
overlap_box = get_minbox_if_overlap_by_ratio(
block1_bbox, block2_bbox, 0.8
)
if overlap_box is not None:
block_to_remove = next(
(block for block in all_bboxes if block[:4] == overlap_box),
None,
)
if (
block_to_remove is not None
and block_to_remove not in need_remove
):
large_block = block1 if block1 != block_to_remove else block2
x1, y1, x2, y2 = large_block[:4]
sx1, sy1, sx2, sy2 = block_to_remove[:4]
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_block[:4] = [x1, y1, x2, y2]
need_remove.append(block_to_remove)
if len(need_remove) > 0:
for block in need_remove:
for i in range(len(all_bboxes)):
for j in range(i + 1, len(all_bboxes)):
block1 = all_bboxes[i]
block2 = all_bboxes[j]
block1_bbox = block1[:4]
block2_bbox = block2[:4]
overlap_box = get_minbox_if_overlap_by_ratio(
block1_bbox, block2_bbox, 0.8
)
if overlap_box is not None:
# 判断哪个区块的面积更小,移除较小的区块
area1 = (block1[2] - block1[0]) * (block1[3] - block1[1])
area2 = (block2[2] - block2[0]) * (block2[3] - block2[1])
if area1 <= area2:
block_to_remove = block1
large_block = block2
else:
block_to_remove = block2
large_block = block1
if block_to_remove not in need_remove:
x1, y1, x2, y2 = large_block[:4]
sx1, sy1, sx2, sy2 = block_to_remove[:4]
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_block[:4] = [x1, y1, x2, y2]
need_remove.append(block_to_remove)
for block in need_remove:
if block in all_bboxes:
all_bboxes.remove(block)
return all_bboxes
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment