Unverified commit 82eb7473, authored by Xiaomeng Zhao, committed by GitHub
Browse files

Merge pull request #2989 from myhloli/dev

feat: improve heading level feature with enhanced configuration and error handling
parents 59b4dd19 2d742bca
...@@ -11,13 +11,18 @@ from mineru.utils.pdf_image_tools import get_crop_img ...@@ -11,13 +11,18 @@ from mineru.utils.pdf_image_tools import get_crop_img
from mineru.version import __version__ from mineru.version import __version__
heading_level_import_success = False heading_level_import_success = False
try: llm_aided_config = get_llm_aided_config()
from mineru.utils.llm_aided import llm_aided_title if llm_aided_config is not None:
from mineru.backend.pipeline.model_init import AtomModelSingleton title_aided_config = llm_aided_config.get('title_aided', None)
heading_level_import_success = True if title_aided_config is not None:
except Exception as e: if title_aided_config.get('enable', False):
logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, " try:
"please execute `pip install mineru[pipeline]` to install the required packages.") from mineru.utils.llm_aided import llm_aided_title
from mineru.backend.pipeline.model_init import AtomModelSingleton
heading_level_import_success = True
except Exception as e:
logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
"please execute `pip install mineru[core]` to install the required packages.")
def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict: def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
...@@ -38,32 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic ...@@ -38,32 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
title_blocks = magic_model.get_title_blocks() title_blocks = magic_model.get_title_blocks()
# 如果有标题优化需求,则对title_blocks截图det # 如果有标题优化需求,则对title_blocks截图det
llm_aided_config = get_llm_aided_config() if heading_level_import_success:
if llm_aided_config is not None: atom_model_manager = AtomModelSingleton()
title_aided_config = llm_aided_config.get('title_aided', None) ocr_model = atom_model_manager.get_atom_model(
if title_aided_config is not None: atom_model_name='ocr',
if title_aided_config.get('enable', False): ocr_show_log=False,
if heading_level_import_success: det_db_box_thresh=0.3,
atom_model_manager = AtomModelSingleton() lang='ch_lite'
ocr_model = atom_model_manager.get_atom_model( )
atom_model_name='ocr', for title_block in title_blocks:
ocr_show_log=False, title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
det_db_box_thresh=0.3, title_np_img = np.array(title_pil_img)
lang='ch_lite' # 给title_pil_img添加上下左右各50像素白边padding
) title_np_img = cv2.copyMakeBorder(
for title_block in title_blocks: title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale) )
title_np_img = np.array(title_pil_img) title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
# 给title_pil_img添加上下左右各50像素白边padding ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
title_np_img = cv2.copyMakeBorder( if len(ocr_det_res) > 0:
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255] # 计算所有res的平均高度
) avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR) title_block['line_avg_height'] = round(avg_height/scale)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# 计算所有res的平均高度
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
text_blocks = magic_model.get_text_blocks() text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks() interline_equation_blocks = magic_model.get_interline_equation_blocks()
...@@ -91,17 +91,11 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer): ...@@ -91,17 +91,11 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
page_info = token_to_page_info(token, image_dict, page, image_writer, index) page_info = token_to_page_info(token, image_dict, page, image_writer, index)
middle_json["pdf_info"].append(page_info) middle_json["pdf_info"].append(page_info)
"""llm优化""" """llm优化标题分级"""
llm_aided_config = get_llm_aided_config() if heading_level_import_success:
if llm_aided_config is not None: llm_aided_title_start_time = time.time()
"""标题优化""" llm_aided_title(middle_json["pdf_info"], title_aided_config)
title_aided_config = llm_aided_config.get('title_aided', None) logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
if title_aided_config is not None:
if title_aided_config.get('enable', False):
if heading_level_import_success:
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
# 关闭pdf文档 # 关闭pdf文档
pdf_doc.close() pdf_doc.close()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment