Commit 236a6033 authored by myhloli's avatar myhloli
Browse files

refactor: improve block processing logic and enhance span handling

parent 6f2c3ad8
...@@ -230,18 +230,18 @@ class BatchAnalyze: ...@@ -230,18 +230,18 @@ class BatchAnalyze:
ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
new_image, _lang) new_image, _lang)
if res["category_id"] == 3: # if res["category_id"] == 3 and ocr_res_list_dict['ocr_enable']:
# ocr_result_list中所有bbox的面积之和 # # ocr_result_list中所有bbox的面积之和
ocr_res_area = sum( # ocr_res_area = sum(
get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item) # get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
# 求ocr_res_area和res的面积的比值 # # 求ocr_res_area和res的面积的比值
res_area = get_coords_and_area(res)[4] # res_area = get_coords_and_area(res)[4]
if res_area > 0: # if res_area > 0:
ratio = ocr_res_area / res_area # ratio = ocr_res_area / res_area
if ratio > 0.3: # if ratio > 0.25:
res["category_id"] = 1 # res["category_id"] = 1
else: # else:
continue # continue
ocr_res_list_dict['layout_res'].extend(ocr_result_list) ocr_res_list_dict['layout_res'].extend(ocr_result_list)
...@@ -321,6 +321,8 @@ class BatchAnalyze: ...@@ -321,6 +321,8 @@ class BatchAnalyze:
ocr_text, ocr_score = ocr_res_list[index] ocr_text, ocr_score = ocr_res_list[index]
layout_res_item['text'] = ocr_text layout_res_item['text'] = ocr_text
layout_res_item['score'] = float(f"{ocr_score:.3f}") layout_res_item['score'] = float(f"{ocr_score:.3f}")
if ocr_score < 0.6:
layout_res_item['category_id'] = 16
total_processed += len(img_crop_list) total_processed += len(img_crop_list)
......
...@@ -8,6 +8,7 @@ from mineru.backend.pipeline.model_init import AtomModelSingleton ...@@ -8,6 +8,7 @@ from mineru.backend.pipeline.model_init import AtomModelSingleton
from mineru.backend.pipeline.para_split import para_split from mineru.backend.pipeline.para_split import para_split
from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups from mineru.utils.block_pre_proc import prepare_block_bboxes, process_groups
from mineru.utils.block_sort import sort_blocks_by_bbox from mineru.utils.block_sort import sort_blocks_by_bbox
from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
from mineru.utils.cut_image import cut_image_and_table from mineru.utils.cut_image import cut_image_and_table
from mineru.utils.llm_aided import llm_aided_title from mineru.utils.llm_aided import llm_aided_title
from mineru.utils.model_utils import clean_memory from mineru.utils.model_utils import clean_memory
...@@ -27,22 +28,48 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer ...@@ -27,22 +28,48 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
magic_model = MagicModel(page_model_info, scale) magic_model = MagicModel(page_model_info, scale)
"""从magic_model对象中获取后面会用到的区块信息""" """从magic_model对象中获取后面会用到的区块信息"""
discarded_blocks = magic_model.get_discarded()
text_blocks = magic_model.get_text_blocks()
title_blocks = magic_model.get_title_blocks()
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations()
img_groups = magic_model.get_imgs() img_groups = magic_model.get_imgs()
table_groups = magic_model.get_tables() table_groups = magic_model.get_tables()
"""对image和table的区块分组""" """对image和table的区块分组"""
img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups( img_body_blocks, img_caption_blocks, img_footnote_blocks, maybe_text_image_blocks = process_groups(
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list' img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
) )
table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups( table_body_blocks, table_caption_blocks, table_footnote_blocks, _ = process_groups(
table_groups, 'table_body', 'table_caption_list', 'table_footnote_list' table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
) )
discarded_blocks = magic_model.get_discarded() """获取所有的spans信息"""
text_blocks = magic_model.get_text_blocks() spans = magic_model.get_all_spans()
title_blocks = magic_model.get_title_blocks()
inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations() if len(maybe_text_image_blocks) > 0:
for block in maybe_text_image_blocks:
span_in_block_list = []
for span in spans:
if span['type'] == 'text' and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block['bbox']) > 0.7:
span_in_block_list.append(span)
if len(span_in_block_list) > 0:
# span_in_block_list中所有bbox的面积之和
spans_area = sum((span['bbox'][2] - span['bbox'][0]) * (span['bbox'][3] - span['bbox'][1]) for span in span_in_block_list)
# 求ocr_res_area和res的面积的比值
block_area = (block['bbox'][2] - block['bbox'][0]) * (block['bbox'][3] - block['bbox'][1])
if block_area > 0:
ratio = spans_area / block_area
if ratio > 0.25 and ocr:
# 移除block的group_id
block.pop('group_id', None)
text_blocks.append(block)
else:
img_body_blocks.append(block)
else:
img_body_blocks.append(block)
"""将所有区块的bbox整理到一起""" """将所有区块的bbox整理到一起"""
interline_equation_blocks = [] interline_equation_blocks = []
...@@ -68,8 +95,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer ...@@ -68,8 +95,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
page_w, page_w,
page_h, page_h,
) )
"""获取所有的spans信息"""
spans = magic_model.get_all_spans()
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span""" """在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span""" """顺便删除大水印并保留abandon的span"""
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks) spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
......
...@@ -12,7 +12,14 @@ def process_groups(groups, body_key, caption_key, footnote_key): ...@@ -12,7 +12,14 @@ def process_groups(groups, body_key, caption_key, footnote_key):
body_blocks = [] body_blocks = []
caption_blocks = [] caption_blocks = []
footnote_blocks = [] footnote_blocks = []
maybe_text_image_blocks = []
for i, group in enumerate(groups): for i, group in enumerate(groups):
if body_key == 'image_body' and len(group[caption_key]) == 0 and len(group[footnote_key]) == 0:
# 如果没有caption和footnote,则不需要将group_id添加到image_body中
group[body_key]['group_id'] = i
maybe_text_image_blocks.append(group[body_key])
continue
else:
group[body_key]['group_id'] = i group[body_key]['group_id'] = i
body_blocks.append(group[body_key]) body_blocks.append(group[body_key])
for caption_block in group[caption_key]: for caption_block in group[caption_key]:
...@@ -21,7 +28,7 @@ def process_groups(groups, body_key, caption_key, footnote_key): ...@@ -21,7 +28,7 @@ def process_groups(groups, body_key, caption_key, footnote_key):
for footnote_block in group[footnote_key]: for footnote_block in group[footnote_key]:
footnote_block['group_id'] = i footnote_block['group_id'] = i
footnote_blocks.append(footnote_block) footnote_blocks.append(footnote_block)
return body_blocks, caption_blocks, footnote_blocks return body_blocks, caption_blocks, footnote_blocks, maybe_text_image_blocks
def prepare_block_bboxes( def prepare_block_bboxes(
......
...@@ -148,17 +148,6 @@ def calculate_iou(bbox1, bbox2): ...@@ -148,17 +148,6 @@ def calculate_iou(bbox1, bbox2):
return iou return iou
def _is_in(box1, box2) -> bool:
"""box1是否完全在box2里面."""
x0_1, y0_1, x1_1, y1_1 = box1
x0_2, y0_2, x1_2, y1_2 = box2
return (x0_1 >= x0_2 and # box1的左边界不在box2的左边外
y0_1 >= y0_2 and # box1的上边界不在box2的上边外
x1_1 <= x1_2 and # box1的右边界不在box2的右边外
y1_1 <= y1_2) # box1的下边界不在box2的下边外
def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2): def calculate_overlap_area_in_bbox1_area_ratio(bbox1, bbox2):
"""计算box1和box2的重叠面积占bbox1的比例.""" """计算box1和box2的重叠面积占bbox1的比例."""
# Determine the coordinates of the intersection rectangle # Determine the coordinates of the intersection rectangle
......
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, _is_in from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
from mineru.utils.enum_class import CategoryId, ContentType from mineru.utils.enum_class import CategoryId, ContentType
...@@ -156,7 +156,7 @@ class MagicModel: ...@@ -156,7 +156,7 @@ class MagicModel:
for j in range(N): for j in range(N):
if i == j: if i == j:
continue continue
if _is_in(bboxes[i]['bbox'], bboxes[j]['bbox']): if is_in(bboxes[i]['bbox'], bboxes[j]['bbox']):
keep[i] = False keep[i] = False
return [bboxes[i] for i in range(N) if keep[i]] return [bboxes[i] for i in range(N) if keep[i]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment