"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5c94937dc7561767892d711e199f874dc35df041"
Unverified Commit 33bea910 authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2948 from myhloli/dev

feat: enhance heading level feature with conditional imports and error handling
parents dd4b60f1 374f464b
import time import time
import cv2
import numpy as np
from loguru import logger from loguru import logger
import numpy as np
from mineru.backend.pipeline.model_init import AtomModelSingleton import cv2
from mineru.utils.config_reader import get_llm_aided_config from mineru.utils.config_reader import get_llm_aided_config
from mineru.utils.cut_image import cut_image_and_table from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import ContentType from mineru.utils.enum_class import ContentType
from mineru.utils.hash_utils import str_md5 from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.pdf_image_tools import get_crop_img from mineru.utils.pdf_image_tools import get_crop_img
from mineru.version import __version__ from mineru.version import __version__
heading_level_import_success = False
try:
from mineru.utils.llm_aided import llm_aided_title
from mineru.backend.pipeline.model_init import AtomModelSingleton
heading_level_import_success = True
except Exception as e:
logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
"please execute `pip install mineru[pipeline]` to install the required packages.")
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict: def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
"""将token转换为页面信息""" """将token转换为页面信息"""
...@@ -37,26 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic ...@@ -37,26 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
title_aided_config = llm_aided_config.get('title_aided', None) title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None: if title_aided_config is not None:
if title_aided_config.get('enable', False): if title_aided_config.get('enable', False):
atom_model_manager = AtomModelSingleton() if heading_level_import_success:
ocr_model = atom_model_manager.get_atom_model( atom_model_manager = AtomModelSingleton()
atom_model_name='ocr', ocr_model = atom_model_manager.get_atom_model(
ocr_show_log=False, atom_model_name='ocr',
det_db_box_thresh=0.3, ocr_show_log=False,
lang='ch_lite' det_db_box_thresh=0.3,
) lang='ch_lite'
for title_block in title_blocks:
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
title_np_img = np.array(title_pil_img)
# 给title_pil_img添加上下左右各50像素白边padding
title_np_img = cv2.copyMakeBorder(
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
) )
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR) for title_block in title_blocks:
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0] title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
if len(ocr_det_res) > 0: title_np_img = np.array(title_pil_img)
# 计算所有res的平均高度 # 给title_pil_img添加上下左右各50像素白边padding
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res]) title_np_img = cv2.copyMakeBorder(
title_block['line_avg_height'] = round(avg_height/scale) title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# 计算所有res的平均高度
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
text_blocks = magic_model.get_text_blocks() text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks() interline_equation_blocks = magic_model.get_interline_equation_blocks()
...@@ -86,15 +93,15 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer): ...@@ -86,15 +93,15 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
"""llm优化""" """llm优化"""
llm_aided_config = get_llm_aided_config() llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None: if llm_aided_config is not None:
"""标题优化""" """标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None) title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None: if title_aided_config is not None:
if title_aided_config.get('enable', False): if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time() if heading_level_import_success:
llm_aided_title(middle_json["pdf_info"], title_aided_config) llm_aided_title_start_time = time.time()
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}') llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
# 关闭pdf文档 # 关闭pdf文档
pdf_doc.close() pdf_doc.close()
......
...@@ -114,12 +114,15 @@ async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True ...@@ -114,12 +114,15 @@ async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True
return md_content, txt_content, archive_zip_path, new_pdf_path return md_content, txt_content, archive_zip_path, new_pdf_path
latex_delimiters = [ latex_delimiters_type_a = [
{'left': '$$', 'right': '$$', 'display': True}, {'left': '$$', 'right': '$$', 'display': True},
{'left': '$', 'right': '$', 'display': False}, {'left': '$', 'right': '$', 'display': False},
]
latex_delimiters_type_b = [
{'left': '\\(', 'right': '\\)', 'display': False}, {'left': '\\(', 'right': '\\)', 'display': False},
{'left': '\\[', 'right': '\\]', 'display': True}, {'left': '\\[', 'right': '\\]', 'display': True},
] ]
latex_delimiters_type_all = latex_delimiters_type_a + latex_delimiters_type_b
header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html') header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html')
with open(header_path, 'r') as header_file: with open(header_path, 'r') as header_file:
...@@ -234,13 +237,30 @@ def update_interface(backend_choice): ...@@ -234,13 +237,30 @@ def update_interface(backend_choice):
help="Set the server port for the Gradio app.", help="Set the server port for the Gradio app.",
default=None, default=None,
) )
@click.option(
'--latex-delimiters-type',
'latex_delimiters_type',
type=click.Choice(['a', 'b', 'all']),
help="Set the type of LaTeX delimiters to use in Markdown rendering:"
"'a' for type '$', 'b' for type '()[]', 'all' for both types.",
default='all',
)
def main(ctx, def main(ctx,
example_enable, sglang_engine_enable, api_enable, max_convert_pages, example_enable, sglang_engine_enable, api_enable, max_convert_pages,
server_name, server_port, **kwargs server_name, server_port, latex_delimiters_type, **kwargs
): ):
kwargs.update(arg_parse(ctx)) kwargs.update(arg_parse(ctx))
if latex_delimiters_type == 'a':
latex_delimiters = latex_delimiters_type_a
elif latex_delimiters_type == 'b':
latex_delimiters = latex_delimiters_type_b
elif latex_delimiters_type == 'all':
latex_delimiters = latex_delimiters_type_all
else:
raise ValueError(f"Invalid latex delimiters type: {latex_delimiters_type}.")
if sglang_engine_enable: if sglang_engine_enable:
try: try:
print("Start init SgLang engine...") print("Start init SgLang engine...")
......
...@@ -33,6 +33,8 @@ dependencies = [ ...@@ -33,6 +33,8 @@ dependencies = [
"modelscope>=1.26.0", "modelscope>=1.26.0",
"huggingface-hub>=0.32.4", "huggingface-hub>=0.32.4",
"json-repair>=0.46.2", "json-repair>=0.46.2",
"opencv-python>=4.11.0.86",
"fast-langdetect>=0.2.3,<0.3.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]
...@@ -60,7 +62,6 @@ pipeline = [ ...@@ -60,7 +62,6 @@ pipeline = [
"torch>=2.2.2,!=2.5.0,!=2.5.1,<3", "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
"torchvision", "torchvision",
"transformers>=4.49.0,!=4.51.0,<5.0.0", "transformers>=4.49.0,!=4.51.0,<5.0.0",
"fast-langdetect>=0.2.3,<0.3.0",
] ]
api = [ api = [
"fastapi", "fastapi",
...@@ -97,7 +98,6 @@ pipeline_old_linux = [ ...@@ -97,7 +98,6 @@ pipeline_old_linux = [
"torch>=2.2.2,!=2.5.0,!=2.5.1,<3", "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
"torchvision", "torchvision",
"transformers>=4.49.0,!=4.51.0,<5.0.0", "transformers>=4.49.0,!=4.51.0,<5.0.0",
"fast-langdetect>=0.2.3,<0.3.0",
] ]
[project.urls] [project.urls]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment