Commit 07edefaa authored by myhloli's avatar myhloli
Browse files

feat(model): add text region handling and improve overlap resolution

- Add text region handling in get_res_list_from_layout_res function
- Implement remove_overlaps_min_blocks function to handle overlapping blocks
- Update OCR region handling to include text regions
- Improve overlap resolution for all regions in layout results
parent 73ccfbbf
...@@ -2,6 +2,8 @@ import time ...@@ -2,6 +2,8 @@ import time
import torch import torch
from loguru import logger from loguru import logger
import numpy as np import numpy as np
from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio
from magic_pdf.libs.clean_memory import clean_memory from magic_pdf.libs.clean_memory import clean_memory
...@@ -188,9 +190,46 @@ def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0 ...@@ -188,9 +190,46 @@ def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0
return [table for i, table in enumerate(table_res_list) if i not in big_tables_idx] return [table for i, table in enumerate(table_res_list) if i not in big_tables_idx]
def remove_overlaps_min_blocks(res_list):
# 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
# 删除重叠blocks中较小的那些
need_remove = []
for res1 in res_list:
for res2 in res_list:
if res1 != res2:
overlap_box = get_minbox_if_overlap_by_ratio(
res1['bbox'], res2['bbox'], 0.8
)
if overlap_box is not None:
res_to_remove = next(
(res for res in res_list if res['bbox'] == overlap_box),
None,
)
if (
res_to_remove is not None
and res_to_remove not in need_remove
):
large_res = res1 if res1 != res_to_remove else res2
x1, y1, x2, y2 = large_res['bbox']
sx1, sy1, sx2, sy2 = res_to_remove['bbox']
x1 = min(x1, sx1)
y1 = min(y1, sy1)
x2 = max(x2, sx2)
y2 = max(y2, sy2)
large_res['bbox'] = [x1, y1, x2, y2]
need_remove.append(res_to_remove)
if len(need_remove) > 0:
for res in need_remove:
res_list.remove(res)
return res_list, need_remove
def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8): def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
"""Extract OCR, table and other regions from layout results.""" """Extract OCR, table and other regions from layout results."""
ocr_res_list = [] ocr_res_list = []
text_res_list = []
table_res_list = [] table_res_list = []
table_indices = [] table_indices = []
single_page_mfdetrec_res = [] single_page_mfdetrec_res = []
...@@ -204,11 +243,14 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol ...@@ -204,11 +243,14 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
"bbox": [int(res['poly'][0]), int(res['poly'][1]), "bbox": [int(res['poly'][0]), int(res['poly'][1]),
int(res['poly'][4]), int(res['poly'][5])], int(res['poly'][4]), int(res['poly'][5])],
}) })
elif category_id in [0, 1, 2, 4, 6, 7]: # OCR regions elif category_id in [0, 2, 4, 6, 7]: # OCR regions
ocr_res_list.append(res) ocr_res_list.append(res)
elif category_id == 5: # Table regions elif category_id == 5: # Table regions
table_res_list.append(res) table_res_list.append(res)
table_indices.append(i) table_indices.append(i)
elif category_id in [1]: # Text regions
res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
text_res_list.append(res)
# Process tables: merge high IoU tables first, then filter nested tables # Process tables: merge high IoU tables first, then filter nested tables
table_res_list, table_indices = merge_high_iou_tables( table_res_list, table_indices = merge_high_iou_tables(
...@@ -226,6 +268,22 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol ...@@ -226,6 +268,22 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
for idx in sorted(to_remove, reverse=True): for idx in sorted(to_remove, reverse=True):
del layout_res[idx] del layout_res[idx]
# Remove overlaps in OCR and text regions
text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
for res in text_res_list:
# 将res的poly使用bbox重构
res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
# 删除res的bbox
del res['bbox']
ocr_res_list.extend(text_res_list)
if len(need_remove) > 0:
for res in need_remove:
del res['bbox']
layout_res.remove(res)
return ocr_res_list, filtered_table_res_list, single_page_mfdetrec_res return ocr_res_list, filtered_table_res_list, single_page_mfdetrec_res
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment