Commit 7e6926ff authored by myhloli's avatar myhloli
Browse files

feat: enhance heading level feature with conditional imports and error handling

parent ca7a567e
import time
import cv2
import numpy as np
from loguru import logger
from mineru.backend.pipeline.model_init import AtomModelSingleton
import numpy as np
import cv2
from mineru.utils.config_reader import get_llm_aided_config
from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import ContentType
from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.pdf_image_tools import get_crop_img
from mineru.version import __version__
heading_level_import_success = False
try:
from mineru.utils.llm_aided import llm_aided_title
from mineru.backend.pipeline.model_init import AtomModelSingleton
heading_level_import_success = True
except Exception as e:
logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
"please execute `pip install mineru[pipeline]` to install the required packages.")
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
"""将token转换为页面信息"""
......@@ -37,26 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang='ch_lite'
)
for title_block in title_blocks:
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
title_np_img = np.array(title_pil_img)
# 给title_pil_img添加上下左右各50像素白边padding
title_np_img = cv2.copyMakeBorder(
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
if heading_level_import_success:
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang='ch_lite'
)
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# 计算所有res的平均高度
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
for title_block in title_blocks:
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
title_np_img = np.array(title_pil_img)
# 给title_pil_img添加上下左右各50像素白边padding
title_np_img = cv2.copyMakeBorder(
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# 计算所有res的平均高度
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks()
......@@ -86,15 +93,15 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
if heading_level_import_success:
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
# 关闭pdf文档
pdf_doc.close()
......
......@@ -33,6 +33,8 @@ dependencies = [
"modelscope>=1.26.0",
"huggingface-hub>=0.32.4",
"json-repair>=0.46.2",
"opencv-python>=4.11.0.86",
"fast-langdetect>=0.2.3,<0.3.0",
]
[project.optional-dependencies]
......@@ -60,7 +62,6 @@ pipeline = [
"torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
"torchvision",
"transformers>=4.49.0,!=4.51.0,<5.0.0",
"fast-langdetect>=0.2.3,<0.3.0",
]
api = [
"fastapi",
......@@ -97,7 +98,6 @@ pipeline_old_linux = [
"torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
"torchvision",
"transformers>=4.49.0,!=4.51.0,<5.0.0",
"fast-langdetect>=0.2.3,<0.3.0",
]
[project.urls]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment