Commit 6094699c authored by myhloli's avatar myhloli
Browse files

refactor: enhance overlap handling in pipeline_magic_model.py for image and table bodies

parent f2666385
from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in from mineru.utils.boxbase import bbox_relative_pos, calculate_iou, bbox_distance, is_in, get_minbox_if_overlap_by_ratio
from mineru.utils.enum_class import CategoryId, ContentType from mineru.utils.enum_class import CategoryId, ContentType
...@@ -13,7 +13,54 @@ class MagicModel: ...@@ -13,7 +13,54 @@ class MagicModel:
self.__fix_by_remove_low_confidence() self.__fix_by_remove_low_confidence()
"""删除高iou(>0.9)数据中置信度较低的那个""" """删除高iou(>0.9)数据中置信度较低的那个"""
self.__fix_by_remove_high_iou_and_low_confidence() self.__fix_by_remove_high_iou_and_low_confidence()
"""将部分tbale_footnote修正为image_footnote"""
self.__fix_footnote() self.__fix_footnote()
"""处理重叠的image_body和table_body"""
self.__fix_by_remove_overlap_image_table_body()
def __fix_by_remove_overlap_image_table_body(self):
    """Merge heavily overlapping image bodies and, separately, table bodies.

    Every pair of detections of the same category is passed to
    ``get_minbox_if_overlap_by_ratio`` with a ratio of 0.8. When that
    helper returns a box, the block of the pair whose bbox equals the
    returned box is scheduled for removal, and the other block's bbox is
    grown to the union of the two boxes. Scheduled blocks are dropped from
    ``layout_dets`` in place once all pairs have been examined. Image
    bodies are never compared against table bodies.
    """
    layout_dets = self.__page_model_info['layout_dets']
    image_blocks = [
        det for det in layout_dets
        if det['category_id'] == CategoryId.ImageBody
    ]
    table_blocks = [
        det for det in layout_dets
        if det['category_id'] == CategoryId.TableBody
    ]
    need_remove_list = []

    def is_marked(block):
        # Identity, not dict equality: two detections with identical
        # contents are still distinct blocks and must be tracked separately.
        return any(block is marked for marked in need_remove_list)

    def add_need_remove_block(blocks):
        for i, block1 in enumerate(blocks):
            for block2 in blocks[i + 1:]:
                overlap_box = get_minbox_if_overlap_by_ratio(
                    block1['bbox'], block2['bbox'], 0.8
                )
                if overlap_box is None:
                    continue

                # Pick the block to drop from this pair only. Searching the
                # whole list for a matching bbox could select an unrelated
                # third block that happens to share the same coordinates.
                if block1['bbox'] == overlap_box:
                    block_to_remove, large_block = block1, block2
                elif block2['bbox'] == overlap_box:
                    block_to_remove, large_block = block2, block1
                else:
                    # The returned box belongs to neither block of the pair;
                    # there is nothing safe to merge.
                    continue

                if is_marked(block_to_remove):
                    continue

                # Grow the surviving block so it also covers the dropped one.
                x1, y1, x2, y2 = large_block['bbox']
                sx1, sy1, sx2, sy2 = block_to_remove['bbox']
                large_block['bbox'] = [
                    min(x1, sx1),
                    min(y1, sy1),
                    max(x2, sx2),
                    max(y2, sy2),
                ]
                need_remove_list.append(block_to_remove)

    add_need_remove_block(image_blocks)
    add_need_remove_block(table_blocks)

    if need_remove_list:
        # Slice assignment keeps the same list object that
        # self.__page_model_info refers to; filtering by identity removes
        # exactly the scheduled blocks, never an equal-looking sibling.
        layout_dets[:] = [det for det in layout_dets if not is_marked(det)]
def __fix_axis(self): def __fix_axis(self):
need_remove_list = [] need_remove_list = []
...@@ -46,42 +93,46 @@ class MagicModel: ...@@ -46,42 +93,46 @@ class MagicModel:
def __fix_by_remove_high_iou_and_low_confidence(self):
    """Drop the lower-scoring detection of each pair whose IoU exceeds 0.9.

    Only detections whose category is one of the layout categories listed
    below take part in the comparison. For every such pair with
    ``calculate_iou`` above 0.9, the detection with the smaller ``score``
    is discarded (the second of the pair when the first is not strictly
    lower). Discarded detections are removed from ``layout_dets`` in place.
    """
    all_dets = self.__page_model_info['layout_dets']
    compared_categories = [
        CategoryId.Title,
        CategoryId.Text,
        CategoryId.ImageBody,
        CategoryId.ImageCaption,
        CategoryId.TableBody,
        CategoryId.TableCaption,
        CategoryId.TableFootnote,
        CategoryId.InterlineEquation_Layout,
        CategoryId.InterlineEquationNumber_Layout,
    ]
    candidates = [
        det for det in all_dets
        if det['category_id'] in compared_categories
    ]

    discarded = []
    for idx, first in enumerate(candidates):
        for second in candidates[idx + 1:]:
            if not calculate_iou(first['bbox'], second['bbox']) > 0.9:
                continue
            if first['score'] < second['score']:
                loser = first
            else:
                loser = second
            # Avoid scheduling the same detection twice.
            if loser not in discarded:
                discarded.append(loser)

    for loser in discarded:
        all_dets.remove(loser)
def __fix_footnote(self): def __fix_footnote(self):
# 3: figure, 5: table, 7: footnote
footnotes = [] footnotes = []
figures = [] figures = []
tables = [] tables = []
for obj in self.__page_model_info['layout_dets']: for obj in self.__page_model_info['layout_dets']:
if obj['category_id'] == 7: if obj['category_id'] == CategoryId.TableFootnote:
footnotes.append(obj) footnotes.append(obj)
elif obj['category_id'] == 3: elif obj['category_id'] == CategoryId.ImageBody:
figures.append(obj) figures.append(obj)
elif obj['category_id'] == 5: elif obj['category_id'] == CategoryId.TableBody:
tables.append(obj) tables.append(obj)
if len(footnotes) * len(figures) == 0: if len(footnotes) * len(figures) == 0:
continue continue
...@@ -314,10 +365,10 @@ class MagicModel: ...@@ -314,10 +365,10 @@ class MagicModel:
def get_imgs(self): def get_imgs(self):
with_captions = self.__tie_up_category_by_distance_v3( with_captions = self.__tie_up_category_by_distance_v3(
3, 4 CategoryId.ImageBody, CategoryId.ImageCaption
) )
with_footnotes = self.__tie_up_category_by_distance_v3( with_footnotes = self.__tie_up_category_by_distance_v3(
3, CategoryId.ImageFootnote CategoryId.ImageBody, CategoryId.ImageFootnote
) )
ret = [] ret = []
for v in with_captions: for v in with_captions:
...@@ -333,10 +384,10 @@ class MagicModel: ...@@ -333,10 +384,10 @@ class MagicModel:
def get_tables(self) -> list: def get_tables(self) -> list:
with_captions = self.__tie_up_category_by_distance_v3( with_captions = self.__tie_up_category_by_distance_v3(
5, 6 CategoryId.TableBody, CategoryId.TableCaption
) )
with_footnotes = self.__tie_up_category_by_distance_v3( with_footnotes = self.__tie_up_category_by_distance_v3(
5, 7 CategoryId.TableBody, CategoryId.TableFootnote
) )
ret = [] ret = []
for v in with_captions: for v in with_captions:
...@@ -385,20 +436,21 @@ class MagicModel: ...@@ -385,20 +436,21 @@ class MagicModel:
all_spans = [] all_spans = []
layout_dets = self.__page_model_info['layout_dets'] layout_dets = self.__page_model_info['layout_dets']
allow_category_id_list = [3, 5, 13, 14, 15] allow_category_id_list = [
CategoryId.ImageBody,
CategoryId.TableBody,
CategoryId.InlineEquation,
CategoryId.InterlineEquation_YOLO,
CategoryId.OcrText,
]
"""当成span拼接的""" """当成span拼接的"""
# 3: 'image', # 图片
# 5: 'table', # 表格
# 13: 'inline_equation', # 行内公式
# 14: 'interline_equation', # 行间公式
# 15: 'text', # ocr识别文本
for layout_det in layout_dets: for layout_det in layout_dets:
category_id = layout_det['category_id'] category_id = layout_det['category_id']
if category_id in allow_category_id_list: if category_id in allow_category_id_list:
span = {'bbox': layout_det['bbox'], 'score': layout_det['score']} span = {'bbox': layout_det['bbox'], 'score': layout_det['score']}
if category_id == 3: if category_id == CategoryId.ImageBody:
span['type'] = ContentType.IMAGE span['type'] = ContentType.IMAGE
elif category_id == 5: elif category_id == CategoryId.TableBody:
# 获取table模型结果 # 获取table模型结果
latex = layout_det.get('latex', None) latex = layout_det.get('latex', None)
html = layout_det.get('html', None) html = layout_det.get('html', None)
...@@ -407,13 +459,13 @@ class MagicModel: ...@@ -407,13 +459,13 @@ class MagicModel:
elif html: elif html:
span['html'] = html span['html'] = html
span['type'] = ContentType.TABLE span['type'] = ContentType.TABLE
elif category_id == 13: elif category_id == CategoryId.InlineEquation:
span['content'] = layout_det['latex'] span['content'] = layout_det['latex']
span['type'] = ContentType.INLINE_EQUATION span['type'] = ContentType.INLINE_EQUATION
elif category_id == 14: elif category_id == CategoryId.InterlineEquation_YOLO:
span['content'] = layout_det['latex'] span['content'] = layout_det['latex']
span['type'] = ContentType.INTERLINE_EQUATION span['type'] = ContentType.INTERLINE_EQUATION
elif category_id == 15: elif category_id == CategoryId.OcrText:
span['content'] = layout_det['text'] span['content'] = layout_det['text']
span['type'] = ContentType.TEXT span['type'] = ContentType.TEXT
all_spans.append(span) all_spans.append(span)
...@@ -438,4 +490,4 @@ class MagicModel: ...@@ -438,4 +490,4 @@ class MagicModel:
for col in extra_col: for col in extra_col:
block[col] = item.get(col, None) block[col] = item.get(col, None)
blocks.append(block) blocks.append(block)
return blocks return blocks
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment