Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
8e8103a8
Unverified
Commit
8e8103a8
authored
Apr 09, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Apr 09, 2025
Browse files
Merge pull request #2170 from myhloli/dev
feat(model): improve table recognition by merging and filtering tables
parents
56499151
df7ae404
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
189 additions
and
8 deletions
+189
-8
magic_pdf/model/batch_analyze.py
magic_pdf/model/batch_analyze.py
+0
-1
magic_pdf/model/sub_modules/model_utils.py
magic_pdf/model/sub_modules/model_utils.py
+189
-7
No files found.
magic_pdf/model/batch_analyze.py
View file @
8e8103a8
...
...
@@ -150,7 +150,6 @@ class BatchAnalyze:
# 表格识别 table recognition
if
self
.
model
.
apply_table
:
table_start
=
time
.
time
()
table_count
=
0
# for table_res_list_dict in table_res_list_all_page:
for
table_res_dict
in
tqdm
(
table_res_list_all_page
,
desc
=
"Table Predict"
):
_lang
=
table_res_dict
[
'lang'
]
...
...
magic_pdf/model/sub_modules/model_utils.py
View file @
8e8103a8
...
...
@@ -29,22 +29,204 @@ def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
return
return_image
,
return_list
# Select regions for OCR / formula regions / table regions
def
get_res_list_from_layout_res
(
layout_res
):
def
get_coords_and_area
(
table
):
"""Extract coordinates and area from a table."""
xmin
,
ymin
=
int
(
table
[
'poly'
][
0
]),
int
(
table
[
'poly'
][
1
])
xmax
,
ymax
=
int
(
table
[
'poly'
][
4
]),
int
(
table
[
'poly'
][
5
])
area
=
(
xmax
-
xmin
)
*
(
ymax
-
ymin
)
return
xmin
,
ymin
,
xmax
,
ymax
,
area
def
calculate_intersection
(
box1
,
box2
):
"""Calculate intersection coordinates between two boxes."""
intersection_xmin
=
max
(
box1
[
0
],
box2
[
0
])
intersection_ymin
=
max
(
box1
[
1
],
box2
[
1
])
intersection_xmax
=
min
(
box1
[
2
],
box2
[
2
])
intersection_ymax
=
min
(
box1
[
3
],
box2
[
3
])
# Check if intersection is valid
if
intersection_xmax
<=
intersection_xmin
or
intersection_ymax
<=
intersection_ymin
:
return
None
return
intersection_xmin
,
intersection_ymin
,
intersection_xmax
,
intersection_ymax
def
calculate_iou
(
box1
,
box2
):
"""Calculate IoU between two boxes."""
intersection
=
calculate_intersection
(
box1
[:
4
],
box2
[:
4
])
if
not
intersection
:
return
0
intersection_xmin
,
intersection_ymin
,
intersection_xmax
,
intersection_ymax
=
intersection
intersection_area
=
(
intersection_xmax
-
intersection_xmin
)
*
(
intersection_ymax
-
intersection_ymin
)
area1
,
area2
=
box1
[
4
],
box2
[
4
]
union_area
=
area1
+
area2
-
intersection_area
return
intersection_area
/
union_area
if
union_area
>
0
else
0
def
is_inside
(
small_box
,
big_box
,
overlap_threshold
=
0.8
):
"""Check if small_box is inside big_box by at least overlap_threshold."""
intersection
=
calculate_intersection
(
small_box
[:
4
],
big_box
[:
4
])
if
not
intersection
:
return
False
intersection_xmin
,
intersection_ymin
,
intersection_xmax
,
intersection_ymax
=
intersection
intersection_area
=
(
intersection_xmax
-
intersection_xmin
)
*
(
intersection_ymax
-
intersection_ymin
)
# Check if overlap exceeds threshold
return
intersection_area
>=
overlap_threshold
*
small_box
[
4
]
def
do_overlap
(
box1
,
box2
):
"""Check if two boxes overlap."""
return
calculate_intersection
(
box1
[:
4
],
box2
[:
4
])
is
not
None
def
merge_high_iou_tables
(
table_res_list
,
layout_res
,
table_indices
,
iou_threshold
=
0.7
):
"""Merge tables with IoU > threshold."""
if
len
(
table_res_list
)
<
2
:
return
table_res_list
,
table_indices
table_info
=
[
get_coords_and_area
(
table
)
for
table
in
table_res_list
]
merged
=
True
while
merged
:
merged
=
False
i
=
0
while
i
<
len
(
table_res_list
)
-
1
:
j
=
i
+
1
while
j
<
len
(
table_res_list
):
iou
=
calculate_iou
(
table_info
[
i
],
table_info
[
j
])
if
iou
>
iou_threshold
:
# Merge tables by taking their union
x1_min
,
y1_min
,
x1_max
,
y1_max
,
_
=
table_info
[
i
]
x2_min
,
y2_min
,
x2_max
,
y2_max
,
_
=
table_info
[
j
]
union_xmin
=
min
(
x1_min
,
x2_min
)
union_ymin
=
min
(
y1_min
,
y2_min
)
union_xmax
=
max
(
x1_max
,
x2_max
)
union_ymax
=
max
(
y1_max
,
y2_max
)
# Create merged table
merged_table
=
table_res_list
[
i
].
copy
()
merged_table
[
'poly'
][
0
]
=
union_xmin
merged_table
[
'poly'
][
1
]
=
union_ymin
merged_table
[
'poly'
][
2
]
=
union_xmax
merged_table
[
'poly'
][
3
]
=
union_ymin
merged_table
[
'poly'
][
4
]
=
union_xmax
merged_table
[
'poly'
][
5
]
=
union_ymax
merged_table
[
'poly'
][
6
]
=
union_xmin
merged_table
[
'poly'
][
7
]
=
union_ymax
# Update layout_res
to_remove
=
[
table_indices
[
j
],
table_indices
[
i
]]
for
idx
in
sorted
(
to_remove
,
reverse
=
True
):
del
layout_res
[
idx
]
layout_res
.
append
(
merged_table
)
# Update tracking lists
table_indices
=
[
k
if
k
<
min
(
to_remove
)
else
k
-
1
if
k
<
max
(
to_remove
)
else
k
-
2
if
k
>
max
(
to_remove
)
else
len
(
layout_res
)
-
1
for
k
in
table_indices
if
k
not
in
to_remove
]
table_indices
.
append
(
len
(
layout_res
)
-
1
)
# Update table lists
table_res_list
.
pop
(
j
)
table_res_list
.
pop
(
i
)
table_res_list
.
append
(
merged_table
)
# Update table_info
table_info
=
[
get_coords_and_area
(
table
)
for
table
in
table_res_list
]
merged
=
True
break
j
+=
1
if
merged
:
break
i
+=
1
return
table_res_list
,
table_indices
def
filter_nested_tables
(
table_res_list
,
overlap_threshold
=
0.8
,
area_threshold
=
0.8
):
"""Remove big tables containing multiple smaller tables within them."""
if
len
(
table_res_list
)
<
3
:
return
table_res_list
table_info
=
[
get_coords_and_area
(
table
)
for
table
in
table_res_list
]
big_tables_idx
=
[]
for
i
in
range
(
len
(
table_res_list
)):
# Find tables inside this one
tables_inside
=
[
j
for
j
in
range
(
len
(
table_res_list
))
if
i
!=
j
and
is_inside
(
table_info
[
j
],
table_info
[
i
],
overlap_threshold
)]
# Continue if there are at least 2 tables inside
if
len
(
tables_inside
)
>=
2
:
# Check if inside tables overlap with each other
tables_overlap
=
any
(
do_overlap
(
table_info
[
tables_inside
[
idx1
]],
table_info
[
tables_inside
[
idx2
]])
for
idx1
in
range
(
len
(
tables_inside
))
for
idx2
in
range
(
idx1
+
1
,
len
(
tables_inside
)))
# If no overlaps, check area condition
if
not
tables_overlap
:
total_inside_area
=
sum
(
table_info
[
j
][
4
]
for
j
in
tables_inside
)
big_table_area
=
table_info
[
i
][
4
]
if
total_inside_area
>
area_threshold
*
big_table_area
:
big_tables_idx
.
append
(
i
)
return
[
table
for
i
,
table
in
enumerate
(
table_res_list
)
if
i
not
in
big_tables_idx
]
def
get_res_list_from_layout_res
(
layout_res
,
iou_threshold
=
0.7
,
overlap_threshold
=
0.8
,
area_threshold
=
0.8
):
"""Extract OCR, table and other regions from layout results."""
ocr_res_list
=
[]
table_res_list
=
[]
table_indices
=
[]
single_page_mfdetrec_res
=
[]
for
res
in
layout_res
:
if
int
(
res
[
'category_id'
])
in
[
13
,
14
]:
# Categorize regions
for
i
,
res
in
enumerate
(
layout_res
):
category_id
=
int
(
res
[
'category_id'
])
if
category_id
in
[
13
,
14
]:
# Formula regions
single_page_mfdetrec_res
.
append
({
"bbox"
:
[
int
(
res
[
'poly'
][
0
]),
int
(
res
[
'poly'
][
1
]),
int
(
res
[
'poly'
][
4
]),
int
(
res
[
'poly'
][
5
])],
})
elif
int
(
res
[
'
category_id
'
])
in
[
0
,
1
,
2
,
4
,
6
,
7
]:
elif
category_id
in
[
0
,
1
,
2
,
4
,
6
,
7
]:
# OCR regions
ocr_res_list
.
append
(
res
)
elif
int
(
res
[
'
category_id
'
])
in
[
5
]:
elif
category_id
==
5
:
# Table regions
table_res_list
.
append
(
res
)
return
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
table_indices
.
append
(
i
)
# Process tables: merge high IoU tables first, then filter nested tables
table_res_list
,
table_indices
=
merge_high_iou_tables
(
table_res_list
,
layout_res
,
table_indices
,
iou_threshold
)
filtered_table_res_list
=
filter_nested_tables
(
table_res_list
,
overlap_threshold
,
area_threshold
)
# Remove filtered out tables from layout_res
if
len
(
filtered_table_res_list
)
<
len
(
table_res_list
):
kept_tables
=
set
(
id
(
table
)
for
table
in
filtered_table_res_list
)
to_remove
=
[
table_indices
[
i
]
for
i
,
table
in
enumerate
(
table_res_list
)
if
id
(
table
)
not
in
kept_tables
]
for
idx
in
sorted
(
to_remove
,
reverse
=
True
):
del
layout_res
[
idx
]
return
ocr_res_list
,
filtered_table_res_list
,
single_page_mfdetrec_res
def
clean_vram
(
device
,
vram_threshold
=
8
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment