Commit 6094699c authored by myhloli's avatar myhloli
Browse files

refactor: enhance overlap handling in pipeline_magic_model.py for image and table bodies

parent f2666385
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
from mineru.utils.enum_class import CategoryId, ContentType
......@@ -13,7 +13,54 @@ class MagicModel:
self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence()
"""将部分tbale_footnote修正为image_footnote"""
self.__fix_footnote()
"""处理重叠的image_body和table_body"""
self.__fix_by_remove_overlap_image_table_body()
def __fix_by_remove_overlap_image_table_body(self):
    """Merge image bodies, and separately table bodies, that largely overlap.

    Every pair of same-category blocks is passed to
    ``get_minbox_if_overlap_by_ratio`` with a ratio of 0.8.  When the helper
    returns a box, the block of the pair that carries that bbox is dropped
    from ``layout_dets`` and the other block's bbox is grown to the union of
    both boxes.

    NOTE(review): the helper presumably returns the smaller of the two
    bboxes once the overlap ratio is exceeded, and ``None`` otherwise --
    confirm against ``mineru.utils.boxbase``.
    """
    layout_dets = self.__page_model_info['layout_dets']

    # ids of the blocks scheduled for removal.  Identity is used on purpose:
    # detections compare by content, so ``in`` / ``list.remove`` could hit a
    # different detection that merely has equal fields.  The blocks stay
    # referenced by ``layout_dets`` until the end, so their ids are stable.
    removed_ids = set()

    def add_need_remove_block(blocks):
        """Schedule the absorbed block of every overlapping pair in ``blocks``."""
        for i, block1 in enumerate(blocks):
            for block2 in blocks[i + 1:]:
                if id(block1) in removed_ids:
                    # block1 was absorbed by another block; nothing may be
                    # merged into it any more.
                    break
                if id(block2) in removed_ids:
                    continue

                overlap_box = get_minbox_if_overlap_by_ratio(
                    block1['bbox'], block2['bbox'], 0.8
                )
                if overlap_box is None:
                    continue

                # The block to drop must be one of the two compared blocks,
                # never an unrelated block that happens to share the bbox.
                if block1['bbox'] == overlap_box:
                    small_block, large_block = block1, block2
                elif block2['bbox'] == overlap_box:
                    small_block, large_block = block2, block1
                else:
                    continue

                # Grow the surviving block to the union of both boxes.
                x1, y1, x2, y2 = large_block['bbox']
                sx1, sy1, sx2, sy2 = small_block['bbox']
                large_block['bbox'] = [
                    min(x1, sx1),
                    min(y1, sy1),
                    max(x2, sx2),
                    max(y2, sy2),
                ]
                removed_ids.add(id(small_block))

    add_need_remove_block(
        [det for det in layout_dets if det['category_id'] == CategoryId.ImageBody]
    )
    add_need_remove_block(
        [det for det in layout_dets if det['category_id'] == CategoryId.TableBody]
    )

    if removed_ids:
        # Slice assignment keeps the same list object that the page info
        # dict refers to.
        layout_dets[:] = [
            det for det in layout_dets if id(det) not in removed_ids
        ]
def __fix_axis(self):
need_remove_list = []
......@@ -46,42 +93,46 @@ class MagicModel:
def __fix_by_remove_high_iou_and_low_confidence(self):
    """Drop the lower-scored detection of every near-duplicate layout pair.

    Only detections whose category is one of the layout categories listed
    below take part in the comparison.  Two detections count as duplicates
    when ``calculate_iou`` of their bboxes exceeds 0.9; the one with the
    lower ``score`` is removed from the page's ``layout_dets`` (on a tie the
    second detection of the pair is removed).
    """
    all_dets = self.__page_model_info['layout_dets']

    # Named categories replace the former magic-number check
    # ``category_id in [0, ..., 9]``; filtering once up front makes a second
    # per-pair category test unnecessary.
    candidate_categories = [
        CategoryId.Title,
        CategoryId.Text,
        CategoryId.ImageBody,
        CategoryId.ImageCaption,
        CategoryId.TableBody,
        CategoryId.TableCaption,
        CategoryId.TableFootnote,
        CategoryId.InterlineEquation_Layout,
        CategoryId.InterlineEquationNumber_Layout,
    ]
    layout_dets = [
        det for det in all_dets if det['category_id'] in candidate_categories
    ]

    need_remove_list = []
    for i, layout_det1 in enumerate(layout_dets):
        for layout_det2 in layout_dets[i + 1:]:
            if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > 0.9:
                if layout_det1['score'] < layout_det2['score']:
                    layout_det_need_remove = layout_det1
                else:
                    layout_det_need_remove = layout_det2
                if layout_det_need_remove not in need_remove_list:
                    need_remove_list.append(layout_det_need_remove)

    # Remove from the page's own list; ``layout_dets`` above is only a
    # filtered copy, so removing from it would have no lasting effect.
    for need_remove in need_remove_list:
        all_dets.remove(need_remove)
def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
footnotes = []
figures = []
tables = []
for obj in self.__page_model_info['layout_dets']:
if obj['category_id'] == 7:
if obj['category_id'] == CategoryId.TableFootnote:
footnotes.append(obj)
elif obj['category_id'] == 3:
elif obj['category_id'] == CategoryId.ImageBody:
figures.append(obj)
elif obj['category_id'] == 5:
elif obj['category_id'] == CategoryId.TableBody:
tables.append(obj)
if len(footnotes) * len(figures) == 0:
continue
......@@ -314,10 +365,10 @@ class MagicModel:
def get_imgs(self):
with_captions = self.__tie_up_category_by_distance_v3(
3, 4
CategoryId.ImageBody, CategoryId.ImageCaption
)
with_footnotes = self.__tie_up_category_by_distance_v3(
3, CategoryId.ImageFootnote
CategoryId.ImageBody, CategoryId.ImageFootnote
)
ret = []
for v in with_captions:
......@@ -333,10 +384,10 @@ class MagicModel:
def get_tables(self) -> list:
with_captions = self.__tie_up_category_by_distance_v3(
5, 6
CategoryId.TableBody, CategoryId.TableCaption
)
with_footnotes = self.__tie_up_category_by_distance_v3(
5, 7
CategoryId.TableBody, CategoryId.TableFootnote
)
ret = []
for v in with_captions:
......@@ -385,20 +436,21 @@ class MagicModel:
all_spans = []
layout_dets = self.__page_model_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15]
allow_category_id_list = [
CategoryId.ImageBody,
CategoryId.TableBody,
CategoryId.InlineEquation,
CategoryId.InterlineEquation_YOLO,
CategoryId.OcrText,
]
"""当成span拼接的"""
# 3: 'image', # 图片
# 5: 'table', # 表格
# 13: 'inline_equation', # 行内公式
# 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本
for layout_det in layout_dets:
category_id = layout_det['category_id']
if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3:
if category_id == CategoryId.ImageBody:
span['type'] = ContentType.IMAGE
elif category_id == 5:
elif category_id == CategoryId.TableBody:
# 获取table模型结果
latex = layout_det.get('latex', None)
html = layout_det.get('html', None)
......@@ -407,13 +459,13 @@ class MagicModel:
elif html:
span['html'] = html
span['type'] = ContentType.TABLE
elif category_id == 13:
elif category_id == CategoryId.InlineEquation:
span['content'] = layout_det['latex']
span['type'] = ContentType.INLINE_EQUATION
elif category_id == 14:
elif category_id == CategoryId.InterlineEquation_YOLO:
span['content'] = layout_det['latex']
span['type'] = ContentType.INTERLINE_EQUATION
elif category_id == 15:
elif category_id == CategoryId.OcrText:
span['content'] = layout_det['text']
span['type'] = ContentType.TEXT
all_spans.append(span)
......@@ -438,4 +490,4 @@ class MagicModel:
for col in extra_col:
block[col] = item.get(col, None)
blocks.append(block)
return blocks
return blocks
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment