Unverified commit 21bd73ea, authored by Xiaomeng Zhao, committed by GitHub
Browse files

Merge pull request #2850 from myhloli/dev

Dev
parents 7d6b1062 7c95c62d
...@@ -323,7 +323,7 @@ class BatchAnalyze: ...@@ -323,7 +323,7 @@ class BatchAnalyze:
layout_res_item['poly'][4], layout_res_item['poly'][5]] layout_res_item['poly'][4], layout_res_item['poly'][5]]
layout_res_width = layout_res_bbox[2] - layout_res_bbox[0] layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
layout_res_height = layout_res_bbox[3] - layout_res_bbox[1] layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
if ocr_text in ['(204号', '(20', '(2', '(2号'] and ocr_score < 0.8 and layout_res_width < layout_res_height: if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
layout_res_item['category_id'] = 16 layout_res_item['category_id'] = 16
total_processed += len(img_crop_list) total_processed += len(img_crop_list)
......
import re import time
import cv2
import numpy as np
from loguru import logger
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.utils.config_reader import get_llm_aided_config
from mineru.utils.cut_image import cut_image_and_table from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType from mineru.utils.enum_class import ContentType
from mineru.utils.hash_utils import str_md5 from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.pdf_image_tools import get_crop_img
from mineru.version import __version__ from mineru.version import __version__
...@@ -23,6 +30,34 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic ...@@ -23,6 +30,34 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
image_blocks = magic_model.get_image_blocks() image_blocks = magic_model.get_image_blocks()
table_blocks = magic_model.get_table_blocks() table_blocks = magic_model.get_table_blocks()
title_blocks = magic_model.get_title_blocks() title_blocks = magic_model.get_title_blocks()
# 如果有标题优化需求,则对title_blocks截图det
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang='ch_lite'
)
for title_block in title_blocks:
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
title_np_img = np.array(title_pil_img)
# 给title_pil_img添加上下左右各50像素白边padding
title_np_img = cv2.copyMakeBorder(
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# 计算所有res的平均高度
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
text_blocks = magic_model.get_text_blocks() text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks() interline_equation_blocks = magic_model.get_interline_equation_blocks()
...@@ -48,6 +83,19 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer): ...@@ -48,6 +83,19 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
image_dict = images_list[index] image_dict = images_list[index]
page_info = token_to_page_info(token, image_dict, page, image_writer, index) page_info = token_to_page_info(token, image_dict, page, image_writer, index)
middle_json["pdf_info"].append(page_info) middle_json["pdf_info"].append(page_info)
"""llm优化"""
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
"""标题优化"""
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
llm_aided_title_start_time = time.time()
llm_aided_title(middle_json["pdf_info"], title_aided_config)
logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
# 关闭pdf文档 # 关闭pdf文档
pdf_doc.close() pdf_doc.close()
return middle_json return middle_json
......
...@@ -209,14 +209,14 @@ def update_interface(backend_choice): ...@@ -209,14 +209,14 @@ def update_interface(backend_choice):
'mem_fraction_static', 'mem_fraction_static',
type=float, type=float,
help="Set the static memory fraction for SgLang engine. ", help="Set the static memory fraction for SgLang engine. ",
default=0.5, default=None,
) )
@click.option( @click.option(
'--enable-torch-compile', '--enable-torch-compile',
'torch_compile_enable', 'torch_compile_enable',
type=bool, type=bool,
help="Enable torch compile for SgLang engine. ", help="Enable torch compile for SgLang engine. ",
default=True, default=False,
) )
@click.option( @click.option(
'--enable-api', '--enable-api',
...@@ -231,12 +231,19 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil ...@@ -231,12 +231,19 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
print("Start init SgLang engine...") print("Start init SgLang engine...")
from mineru.backend.vlm.vlm_analyze import ModelSingleton from mineru.backend.vlm.vlm_analyze import ModelSingleton
modelsingleton = ModelSingleton() modelsingleton = ModelSingleton()
model_params = {
"enable_torch_compile": torch_compile_enable
}
# 只有当mem_fraction_static不为None时才添加该参数
if mem_fraction_static is not None:
model_params["mem_fraction_static"] = mem_fraction_static
predictor = modelsingleton.get_model( predictor = modelsingleton.get_model(
"sglang-engine", "sglang-engine",
None, None,
None, None,
mem_fraction_static=mem_fraction_static, **model_params
enable_torch_compile=torch_compile_enable,
) )
print("SgLang engine init successfully.") print("SgLang engine init successfully.")
except Exception as e: except Exception as e:
......
# Copyright (c) Opendatalab. All rights reserved. # Copyright (c) Opendatalab. All rights reserved.
from loguru import logger from loguru import logger
from openai import OpenAI from openai import OpenAI
import ast import json_repair
from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text
...@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config): ...@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config):
if block["type"] == "title": if block["type"] == "title":
origin_title_list.append(block) origin_title_list.append(block)
title_text = merge_para_with_text(block) title_text = merge_para_with_text(block)
page_line_height_list = []
for line in block['lines']: if 'line_avg_height' in block:
bbox = line['bbox'] line_avg_height = block['line_avg_height']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
else: else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1]) title_block_line_height_list = []
for line in block['lines']:
bbox = line['bbox']
title_block_line_height_list.append(int(bbox[3] - bbox[1]))
if len(title_block_line_height_list) > 0:
line_avg_height = sum(title_block_line_height_list) / len(title_block_line_height_list)
else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1] title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
i += 1 i += 1
# logger.info(f"Title list: {title_dict}") # logger.info(f"Title list: {title_dict}")
...@@ -91,7 +96,6 @@ Corrected title list: ...@@ -91,7 +96,6 @@ Corrected title list:
if "</think>" in content: if "</think>" in content:
idx = content.index("</think>") + len("</think>") idx = content.index("</think>") + len("</think>")
content = content[idx:].strip() content = content[idx:].strip()
import json_repair
dict_completion = json_repair.loads(content) dict_completion = json_repair.loads(content)
dict_completion = {int(k): int(v) for k, v in dict_completion.items()} dict_completion = {int(k): int(v) for k, v in dict_completion.items()}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment