Commit 7d4ce0c3 authored by myhloli
Browse files

refactor: add LLM-aided title optimization and improve config handling

parent d2de6d80
...@@ -124,4 +124,14 @@ def get_latex_delimiter_config(): ...@@ -124,4 +124,14 @@ def get_latex_delimiter_config():
logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default") logger.warning(f"'latex-delimiter-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
return None return None
else: else:
return latex_delimiter_config return latex_delimiter_config
\ No newline at end of file
def get_llm_aided_config():
    """Return the 'llm-aided-config' entry of the configuration.

    The configuration is obtained from ``read_config()``. When the entry is
    absent (or explicitly ``None``) a warning is logged and ``None`` is
    returned, so callers can treat ``None`` as "LLM assistance not configured".
    """
    section = read_config().get('llm-aided-config')
    if section is not None:
        return section

    # Missing section: tell the user and fall back to None.
    logger.warning(f"'llm-aided-config' not found in {CONFIG_FILE_NAME}, use 'None' as default")
    return None
\ No newline at end of file
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
from mineru.backend.pipeline.config_reader import get_device import time
from loguru import logger
from mineru.backend.pipeline.config_reader import get_device, get_llm_aided_config
from mineru.backend.pipeline.model_init import AtomModelSingleton from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.cut_image import cut_image_and_table from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.model_utils import clean_memory from mineru.utils.model_utils import clean_memory
from mineru.utils.pipeline_magic_model import MagicModel from mineru.utils.pipeline_magic_model import MagicModel
from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
...@@ -169,6 +174,18 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N ...@@ -169,6 +174,18 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N
"""分段""" """分段"""
para_split(middle_json["pdf_info"]) para_split(middle_json["pdf_info"])
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
clean_memory(get_device()) clean_memory(get_device())
return middle_json return middle_json
......
...@@ -215,8 +215,7 @@ def do_parse( ...@@ -215,8 +215,7 @@ def do_parse(
if __name__ == "__main__": if __name__ == "__main__":
# pdf_path = "../../demo/pdfs/计算机学报-单词中间有换行符-span不准确.pdf" pdf_path = "../../demo/pdfs/demo2.pdf"
pdf_path = "../../demo/pdfs/demo1.pdf"
with open(pdf_path, "rb") as f: with open(pdf_path, "rb") as f:
try: try:
do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,) do_parse("./output", [Path(pdf_path).stem], [f.read()],["ch"], end_page_id=20,)
......
# Copyright (c) Opendatalab. All rights reserved.
from loguru import logger
from openai import OpenAI
import ast
from mineru.api.pipeline_middle_json_mkcontent import merge_para_with_text
def _collect_titles(page_info_list):
    """Collect all title blocks and the per-title features sent to the LLM.

    Returns a ``(title_blocks, title_dict)`` pair. ``title_blocks`` holds the
    title block dicts in document order; ``title_dict`` maps the string index
    of each title to ``[title text, average line height, 1-based page number]``.
    """
    title_blocks = []
    title_dict = {}
    for page_info in page_info_list:
        for block in page_info["para_blocks"]:
            if block["type"] != "title":
                continue
            title_text = merge_para_with_text(block)
            line_heights = [int(line['bbox'][3] - line['bbox'][1]) for line in block['lines']]
            if line_heights:
                line_avg_height = sum(line_heights) / len(line_heights)
            else:
                # No lines in the block: fall back to the height of the block itself.
                line_avg_height = int(block['bbox'][3] - block['bbox'][1])
            title_dict[f"{len(title_blocks)}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
            title_blocks.append(block)
    return title_blocks, title_dict


def _parse_title_levels(raw_content):
    """Parse the LLM reply into a ``{title index (int): level (int)}`` dict.

    Only the outermost ``{...}`` span of the reply is evaluated, so replies
    wrapped in code fences or surrounded by extra text are still accepted.
    Keys and values are coerced with ``int`` so that ``{"0": "1"}`` is treated
    like ``{0: 1}``.

    Raises:
        ValueError: if the reply is empty, contains no dict literal, or holds
            keys/values that cannot be converted to ``int``.
        SyntaxError: if the extracted span is not a valid Python literal.
    """
    if not raw_content:
        raise ValueError("empty LLM response")
    start = raw_content.find("{")
    end = raw_content.rfind("}")
    if start == -1 or end < start:
        raise ValueError("no dict literal found in LLM response")
    # ast.literal_eval only evaluates literals, so it is safe on model output.
    parsed = ast.literal_eval(raw_content[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError("LLM response is not a dict")
    return {int(key): int(value) for key, value in parsed.items()}


def llm_aided_title(page_info_list, title_aided_config):
    """Ask an LLM to assign a hierarchy level to every title block.

    Every block of type ``"title"`` found in ``page_info_list`` is described
    to the model (text, average line height, page number). When the model
    returns a level for each title, the level is stored in the block under the
    ``"level"`` key. The blocks are modified in place; nothing is returned.

    Args:
        page_info_list: list of page dicts, each with ``"para_blocks"`` and
            ``"page_idx"``.
        title_aided_config: dict providing ``"api_key"``, ``"base_url"`` and
            ``"model"`` for the OpenAI-compatible endpoint.

    The request is attempted up to three times. Request/parse errors are
    logged and retried; if no attempt yields a complete result, an error is
    logged and the blocks are left untouched.
    """
    origin_title_list, title_dict = _collect_titles(page_info_list)
    if not origin_title_list:
        # Nothing to optimise: skip building the client and the LLM round-trip.
        return

    client = OpenAI(
        api_key=title_aided_config["api_key"],
        base_url=title_aided_config["base_url"],
    )

    # logger.info(f"Title list: {title_dict}")
    title_optimize_prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
1. 字典中每个value均为一个list,包含以下元素:
- 标题文本
- 文本行高是标题所在块的平均行高
- 标题所在的页码
2. 保留原始内容:
- 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
- 请务必保证输出的字典中元素的数量和输入的数量一致
3. 保持字典内key-value的对应关系不变
4. 优化层次结构:
- 为每个标题元素添加适当的层次结构
- 行高较大的标题一般是更高级别的标题
- 标题从前至后的层级必须是连续的,不能跳过层级
- 标题层级最多为4级,不要添加过多的层级
- 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
5. 合理性检查与微调:
- 在完成初步分级后,仔细检查分级结果的合理性
- 根据上下文关系和逻辑顺序,对不合理的分级进行微调
- 确保最终的分级结果符合文档的实际结构和逻辑
- 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们
IMPORTANT:
请直接返回优化过的由标题层级组成的字典,格式为{{标题id:标题层级}},如下:
{{0:1,1:2,2:2,3:3}}
不需要对字典格式化,不需要返回任何其他信息。
Input title list:
{title_dict}
Corrected title list:
"""

    max_retries = 3
    expected_keys = set(range(len(origin_title_list)))
    for _ in range(max_retries):
        try:
            completion = client.chat.completions.create(
                model=title_aided_config["model"],
                messages=[
                    {'role': 'user', 'content': title_optimize_prompt}],
                temperature=0.7,
            )
            # logger.info(f"Title completion: {completion.choices[0].message.content}")
            levels = _parse_title_levels(completion.choices[0].message.content)
        except Exception as e:
            # Best-effort feature: log any request/parse failure and retry.
            logger.exception(e)
            continue

        if len(levels) != len(title_dict):
            logger.warning(
                "The number of titles in the optimized result is not equal to the number of titles in the input.")
            continue
        if set(levels) != expected_keys:
            logger.warning(
                "The title ids in the optimized result do not match the title ids in the input.")
            continue

        # The result is complete, so applying it cannot fail half-way through.
        for index, origin_title_block in enumerate(origin_title_list):
            origin_title_block["level"] = levels[index]
        return

    logger.error("Failed to decode dict after maximum retries.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment