Commit 06db3d17 authored by myhloli
Browse files

feat: enhance title block processing with average height calculation and padding for OCR

parent 35cb414f
import time import time
import cv2
import numpy as np
from loguru import logger from loguru import logger
from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.utils.config_reader import get_llm_aided_config from mineru.utils.config_reader import get_llm_aided_config
from mineru.utils.cut_image import cut_image_and_table from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.enum_class import BlockType, ContentType from mineru.utils.enum_class import ContentType
from mineru.utils.hash_utils import str_md5 from mineru.utils.hash_utils import str_md5
from mineru.backend.vlm.vlm_magic_model import MagicModel from mineru.backend.vlm.vlm_magic_model import MagicModel
from mineru.utils.llm_aided import llm_aided_title from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.pdf_image_tools import get_crop_img
from mineru.version import __version__ from mineru.version import __version__
...@@ -26,6 +30,34 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic ...@@ -26,6 +30,34 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
image_blocks = magic_model.get_image_blocks() image_blocks = magic_model.get_image_blocks()
table_blocks = magic_model.get_table_blocks() table_blocks = magic_model.get_table_blocks()
title_blocks = magic_model.get_title_blocks() title_blocks = magic_model.get_title_blocks()
# If LLM-aided title optimization is enabled, crop each title block and run OCR text detection on the crop
llm_aided_config = get_llm_aided_config()
if llm_aided_config is not None:
title_aided_config = llm_aided_config.get('title_aided', None)
if title_aided_config is not None:
if title_aided_config.get('enable', False):
atom_model_manager = AtomModelSingleton()
ocr_model = atom_model_manager.get_atom_model(
atom_model_name='ocr',
ocr_show_log=False,
det_db_box_thresh=0.3,
lang='ch_lite'
)
for title_block in title_blocks:
title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
title_np_img = np.array(title_pil_img)
# Pad the cropped title image (as a NumPy array) with a 50-pixel white border on the top, bottom, left and right
title_np_img = cv2.copyMakeBorder(
title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
)
title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
if len(ocr_det_res) > 0:
# Compute the average height of all detected text boxes
avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
title_block['line_avg_height'] = round(avg_height/scale)
text_blocks = magic_model.get_text_blocks() text_blocks = magic_model.get_text_blocks()
interline_equation_blocks = magic_model.get_interline_equation_blocks() interline_equation_blocks = magic_model.get_interline_equation_blocks()
......
...@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config): ...@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config):
if block["type"] == "title": if block["type"] == "title":
origin_title_list.append(block) origin_title_list.append(block)
title_text = merge_para_with_text(block) title_text = merge_para_with_text(block)
page_line_height_list = []
for line in block['lines']: if 'line_avg_height' in block:
bbox = line['bbox'] line_avg_height = block['line_avg_height']
page_line_height_list.append(int(bbox[3] - bbox[1]))
if len(page_line_height_list) > 0:
line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
else: else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1]) title_block_line_height_list = []
for line in block['lines']:
bbox = line['bbox']
title_block_line_height_list.append(int(bbox[3] - bbox[1]))
if len(title_block_line_height_list) > 0:
line_avg_height = sum(title_block_line_height_list) / len(title_block_line_height_list)
else:
line_avg_height = int(block['bbox'][3] - block['bbox'][1])
title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1] title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
i += 1 i += 1
# logger.info(f"Title list: {title_dict}") # logger.info(f"Title list: {title_dict}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment