Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
9bf8be98
Commit
9bf8be98
authored
Jul 23, 2025
by
myhloli
Browse files
feat: add utility functions for bounding box processing in magic_model_utils.py
parent
4f612cbc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
282 additions
and
372 deletions
+282
-372
mineru/backend/pipeline/pipeline_magic_model.py
mineru/backend/pipeline/pipeline_magic_model.py
+13
-203
mineru/backend/vlm/vlm_magic_model.py
mineru/backend/vlm/vlm_magic_model.py
+8
-169
mineru/utils/magic_model_utils.py
mineru/utils/magic_model_utils.py
+261
-0
No files found.
mineru/backend/pipeline/pipeline_magic_model.py
View file @
9bf8be98
from
mineru.utils.boxbase
import
bbox_relative_pos
,
calculate_iou
,
bbox_distance
,
is_in
,
get_minbox_if_overlap_by_ratio
from
mineru.utils.boxbase
import
bbox_relative_pos
,
get_minbox_if_overlap_by_ratio
from
mineru.utils.enum_class
import
CategoryId
,
ContentType
from
mineru.utils.enum_class
import
CategoryId
,
ContentType
import
mineru.utils.magic_model_utils
as
magic_model_utils
class
MagicModel
:
class
MagicModel
:
...
@@ -89,18 +90,9 @@ class MagicModel:
...
@@ -89,18 +90,9 @@ class MagicModel:
layout_dets
.
remove
(
need_remove
)
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_low_confidence
(
self
):
def
__fix_by_remove_low_confidence
(
self
):
need_remove_list
=
[]
magic_model_utils
.
remove_low_confidence
(
self
.
__page_model_info
[
'layout_dets'
])
layout_dets
=
self
.
__page_model_info
[
'layout_dets'
]
for
layout_det
in
layout_dets
:
if
layout_det
[
'score'
]
<=
0.05
:
need_remove_list
.
append
(
layout_det
)
else
:
continue
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_high_iou_and_low_confidence
(
self
):
def
__fix_by_remove_high_iou_and_low_confidence
(
self
):
need_remove_list
=
[]
layout_dets
=
list
(
filter
(
layout_dets
=
list
(
filter
(
lambda
x
:
x
[
'category_id'
]
in
[
lambda
x
:
x
[
'category_id'
]
in
[
CategoryId
.
Title
,
CategoryId
.
Title
,
...
@@ -115,20 +107,7 @@ class MagicModel:
...
@@ -115,20 +107,7 @@ class MagicModel:
],
self
.
__page_model_info
[
'layout_dets'
]
],
self
.
__page_model_info
[
'layout_dets'
]
)
)
)
)
for
i
in
range
(
len
(
layout_dets
)):
magic_model_utils
.
remove_high_iou_low_confidence
(
layout_dets
)
for
j
in
range
(
i
+
1
,
len
(
layout_dets
)):
layout_det1
=
layout_dets
[
i
]
layout_det2
=
layout_dets
[
j
]
if
calculate_iou
(
layout_det1
[
'bbox'
],
layout_det2
[
'bbox'
])
>
0.9
:
layout_det_need_remove
=
layout_det1
if
layout_det1
[
'score'
]
<
layout_det2
[
'score'
]
else
layout_det2
if
layout_det_need_remove
not
in
need_remove_list
:
need_remove_list
.
append
(
layout_det_need_remove
)
for
need_remove
in
need_remove_list
:
self
.
__page_model_info
[
'layout_dets'
].
remove
(
need_remove
)
def
__fix_footnote
(
self
):
def
__fix_footnote
(
self
):
footnotes
=
[]
footnotes
=
[]
...
@@ -162,7 +141,7 @@ class MagicModel:
...
@@ -162,7 +141,7 @@ class MagicModel:
if
pos_flag_count
>
1
:
if
pos_flag_count
>
1
:
continue
continue
dis_figure_footnote
[
i
]
=
min
(
dis_figure_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
magic_model_utils
.
bbox_distance_with_relative_check
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
)
)
for
i
in
range
(
len
(
footnotes
)):
for
i
in
range
(
len
(
footnotes
)):
...
@@ -181,7 +160,7 @@ class MagicModel:
...
@@ -181,7 +160,7 @@ class MagicModel:
continue
continue
dis_table_footnote
[
i
]
=
min
(
dis_table_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
magic_model_utils
.
bbox_distance_with_relative_check
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
)
)
for
i
in
range
(
len
(
footnotes
)):
for
i
in
range
(
len
(
footnotes
)):
...
@@ -190,189 +169,20 @@ class MagicModel:
...
@@ -190,189 +169,20 @@ class MagicModel:
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
def
_bbox_distance
(
self
,
bbox1
,
bbox2
):
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
flags
=
[
left
,
right
,
bottom
,
top
]
count
=
sum
([
1
if
v
else
0
for
v
in
flags
])
if
count
>
1
:
return
float
(
'inf'
)
if
left
or
right
:
l1
=
bbox1
[
3
]
-
bbox1
[
1
]
l2
=
bbox2
[
3
]
-
bbox2
[
1
]
else
:
l1
=
bbox1
[
2
]
-
bbox1
[
0
]
l2
=
bbox2
[
2
]
-
bbox2
[
0
]
if
l2
>
l1
and
(
l2
-
l1
)
/
l1
>
0.3
:
return
float
(
'inf'
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
__reduct_overlap
(
self
,
bboxes
):
N
=
len
(
bboxes
)
keep
=
[
True
]
*
N
for
i
in
range
(
N
):
for
j
in
range
(
N
):
if
i
==
j
:
continue
if
is_in
(
bboxes
[
i
][
'bbox'
],
bboxes
[
j
][
'bbox'
]):
keep
[
i
]
=
False
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
def
__tie_up_category_by_distance_v3
(
def
__tie_up_category_by_distance_v3
(
self
,
self
,
subject_category_id
:
int
,
subject_category_id
:
int
,
object_category_id
:
int
,
object_category_id
:
int
,
):
):
subjects
=
self
.
__reduct_overlap
(
return
magic_model_utils
.
tie_up_category_by_distance_v3
(
list
(
self
.
__page_model_info
,
map
(
subject_category_id
,
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
object_category_id
,
filter
(
extract_bbox_func
=
lambda
x
:
x
[
'bbox'
],
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
extract_score_func
=
lambda
x
:
x
[
'score'
],
self
.
__page_model_info
[
'layout_dets'
],
create_item_func
=
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]}
),
)
)
)
objects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
self
.
__page_model_info
[
'layout_dets'
],
),
)
)
)
)
ret
=
[]
N
,
M
=
len
(
subjects
),
len
(
objects
)
subjects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
OBJ_IDX_OFFSET
=
10000
SUB_BIT_KIND
,
OBJ_BIT_KIND
=
0
,
1
all_boxes_with_idx
=
[(
i
,
SUB_BIT_KIND
,
sub
[
'bbox'
][
0
],
sub
[
'bbox'
][
1
])
for
i
,
sub
in
enumerate
(
subjects
)]
+
[(
i
+
OBJ_IDX_OFFSET
,
OBJ_BIT_KIND
,
obj
[
'bbox'
][
0
],
obj
[
'bbox'
][
1
])
for
i
,
obj
in
enumerate
(
objects
)]
seen_idx
=
set
()
seen_sub_idx
=
set
()
while
N
>
len
(
seen_sub_idx
):
candidates
=
[]
for
idx
,
kind
,
x0
,
y0
in
all_boxes_with_idx
:
if
idx
in
seen_idx
:
continue
candidates
.
append
((
idx
,
kind
,
x0
,
y0
))
if
len
(
candidates
)
==
0
:
break
left_x
=
min
([
v
[
2
]
for
v
in
candidates
])
top_y
=
min
([
v
[
3
]
for
v
in
candidates
])
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
fst_idx
,
fst_kind
,
left_x
,
top_y
=
candidates
[
0
]
fst_bbox
=
subjects
[
fst_idx
][
'bbox'
]
if
fst_kind
==
SUB_BIT_KIND
else
objects
[
fst_idx
-
OBJ_IDX_OFFSET
][
'bbox'
]
candidates
.
sort
(
key
=
lambda
x
:
bbox_distance
(
fst_bbox
,
subjects
[
x
[
0
]][
'bbox'
])
if
x
[
1
]
==
SUB_BIT_KIND
else
bbox_distance
(
fst_bbox
,
objects
[
x
[
0
]
-
OBJ_IDX_OFFSET
][
'bbox'
]))
nxt
=
None
for
i
in
range
(
1
,
len
(
candidates
)):
if
candidates
[
i
][
1
]
^
fst_kind
==
1
:
nxt
=
candidates
[
i
]
break
if
nxt
is
None
:
break
if
fst_kind
==
SUB_BIT_KIND
:
sub_idx
,
obj_idx
=
fst_idx
,
nxt
[
0
]
-
OBJ_IDX_OFFSET
else
:
sub_idx
,
obj_idx
=
nxt
[
0
],
fst_idx
-
OBJ_IDX_OFFSET
pair_dis
=
bbox_distance
(
subjects
[
sub_idx
][
'bbox'
],
objects
[
obj_idx
][
'bbox'
])
nearest_dis
=
float
(
'inf'
)
for
i
in
range
(
N
):
# 取消原先算法中 1对1 匹配的偏置
# if i in seen_idx or i == sub_idx:continue
nearest_dis
=
min
(
nearest_dis
,
bbox_distance
(
subjects
[
i
][
'bbox'
],
objects
[
obj_idx
][
'bbox'
]))
if
pair_dis
>=
3
*
nearest_dis
:
seen_idx
.
add
(
sub_idx
)
continue
seen_idx
.
add
(
sub_idx
)
seen_idx
.
add
(
obj_idx
+
OBJ_IDX_OFFSET
)
seen_sub_idx
.
add
(
sub_idx
)
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
sub_idx
][
'bbox'
],
'score'
:
subjects
[
sub_idx
][
'score'
],
},
'obj_bboxes'
:
[
{
'score'
:
objects
[
obj_idx
][
'score'
],
'bbox'
:
objects
[
obj_idx
][
'bbox'
]}
],
'sub_idx'
:
sub_idx
,
}
)
for
i
in
range
(
len
(
objects
)):
j
=
i
+
OBJ_IDX_OFFSET
if
j
in
seen_idx
:
continue
seen_idx
.
add
(
j
)
nearest_dis
,
nearest_sub_idx
=
float
(
'inf'
),
-
1
for
k
in
range
(
len
(
subjects
)):
dis
=
bbox_distance
(
objects
[
i
][
'bbox'
],
subjects
[
k
][
'bbox'
])
if
dis
<
nearest_dis
:
nearest_dis
=
dis
nearest_sub_idx
=
k
for
k
in
range
(
len
(
subjects
)):
if
k
!=
nearest_sub_idx
:
continue
if
k
in
seen_sub_idx
:
for
kk
in
range
(
len
(
ret
)):
if
ret
[
kk
][
'sub_idx'
]
==
k
:
ret
[
kk
][
'obj_bboxes'
].
append
({
'score'
:
objects
[
i
][
'score'
],
'bbox'
:
objects
[
i
][
'bbox'
]})
break
else
:
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
k
][
'bbox'
],
'score'
:
subjects
[
k
][
'score'
],
},
'obj_bboxes'
:
[
{
'score'
:
objects
[
i
][
'score'
],
'bbox'
:
objects
[
i
][
'bbox'
]}
],
'sub_idx'
:
k
,
}
)
seen_sub_idx
.
add
(
k
)
seen_idx
.
add
(
k
)
for
i
in
range
(
len
(
subjects
)):
if
i
in
seen_sub_idx
:
continue
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
i
][
'bbox'
],
'score'
:
subjects
[
i
][
'score'
],
},
'obj_bboxes'
:
[],
'sub_idx'
:
i
,
}
)
return
ret
def
get_imgs
(
self
):
def
get_imgs
(
self
):
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
CategoryId
.
ImageBody
,
CategoryId
.
ImageCaption
CategoryId
.
ImageBody
,
CategoryId
.
ImageCaption
...
...
mineru/backend/vlm/vlm_magic_model.py
View file @
9bf8be98
...
@@ -3,11 +3,10 @@ from typing import Literal
...
@@ -3,11 +3,10 @@ from typing import Literal
from
loguru
import
logger
from
loguru
import
logger
from
mineru.utils.boxbase
import
bbox_distance
,
is_in
from
mineru.utils.enum_class
import
ContentType
,
BlockType
,
SplitFlag
from
mineru.utils.enum_class
import
ContentType
,
BlockType
,
SplitFlag
from
mineru.backend.vlm.vlm_middle_json_mkcontent
import
merge_para_with_text
from
mineru.backend.vlm.vlm_middle_json_mkcontent
import
merge_para_with_text
from
mineru.utils.format_utils
import
convert_otsl_to_html
from
mineru.utils.format_utils
import
convert_otsl_to_html
import
mineru.utils.magic_model_utils
as
magic_model_utils
class
MagicModel
:
class
MagicModel
:
def
__init__
(
self
,
token
:
str
,
width
,
height
):
def
__init__
(
self
,
token
:
str
,
width
,
height
):
...
@@ -251,179 +250,19 @@ def latex_fix(latex):
...
@@ -251,179 +250,19 @@ def latex_fix(latex):
return
latex
return
latex
def
__reduct_overlap
(
bboxes
):
N
=
len
(
bboxes
)
keep
=
[
True
]
*
N
for
i
in
range
(
N
):
for
j
in
range
(
N
):
if
i
==
j
:
continue
if
is_in
(
bboxes
[
i
][
"bbox"
],
bboxes
[
j
][
"bbox"
]):
keep
[
i
]
=
False
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
def
__tie_up_category_by_distance_v3
(
def
__tie_up_category_by_distance_v3
(
blocks
:
list
,
blocks
:
list
,
subject_block_type
:
str
,
subject_block_type
:
str
,
object_block_type
:
str
,
object_block_type
:
str
,
):
):
subjects
=
__reduct_overlap
(
return
magic_model_utils
.
tie_up_category_by_distance_v3
(
list
(
blocks
,
map
(
lambda
x
:
x
[
"type"
]
==
subject_block_type
,
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
lambda
x
:
x
[
"type"
]
==
object_block_type
,
filter
(
extract_bbox_func
=
lambda
x
:
x
[
"bbox"
],
lambda
x
:
x
[
"type"
]
==
subject_block_type
,
extract_score_func
=
lambda
x
:
x
.
get
(
"score"
,
1.0
),
blocks
,
create_item_func
=
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]}
),
)
)
)
)
objects
=
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
filter
(
lambda
x
:
x
[
"type"
]
==
object_block_type
,
blocks
,
),
)
)
)
ret
=
[]
N
,
M
=
len
(
subjects
),
len
(
objects
)
subjects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
][
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
][
1
]
**
2
)
OBJ_IDX_OFFSET
=
10000
SUB_BIT_KIND
,
OBJ_BIT_KIND
=
0
,
1
all_boxes_with_idx
=
[(
i
,
SUB_BIT_KIND
,
sub
[
"bbox"
][
0
],
sub
[
"bbox"
][
1
])
for
i
,
sub
in
enumerate
(
subjects
)]
+
[
(
i
+
OBJ_IDX_OFFSET
,
OBJ_BIT_KIND
,
obj
[
"bbox"
][
0
],
obj
[
"bbox"
][
1
])
for
i
,
obj
in
enumerate
(
objects
)
]
seen_idx
=
set
()
seen_sub_idx
=
set
()
while
N
>
len
(
seen_sub_idx
):
candidates
=
[]
for
idx
,
kind
,
x0
,
y0
in
all_boxes_with_idx
:
if
idx
in
seen_idx
:
continue
candidates
.
append
((
idx
,
kind
,
x0
,
y0
))
if
len
(
candidates
)
==
0
:
break
left_x
=
min
([
v
[
2
]
for
v
in
candidates
])
top_y
=
min
([
v
[
3
]
for
v
in
candidates
])
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
fst_idx
,
fst_kind
,
left_x
,
top_y
=
candidates
[
0
]
fst_bbox
=
subjects
[
fst_idx
][
'bbox'
]
if
fst_kind
==
SUB_BIT_KIND
else
objects
[
fst_idx
-
OBJ_IDX_OFFSET
][
'bbox'
]
candidates
.
sort
(
key
=
lambda
x
:
bbox_distance
(
fst_bbox
,
subjects
[
x
[
0
]][
'bbox'
])
if
x
[
1
]
==
SUB_BIT_KIND
else
bbox_distance
(
fst_bbox
,
objects
[
x
[
0
]
-
OBJ_IDX_OFFSET
][
'bbox'
]))
nxt
=
None
for
i
in
range
(
1
,
len
(
candidates
)):
if
candidates
[
i
][
1
]
^
fst_kind
==
1
:
nxt
=
candidates
[
i
]
break
if
nxt
is
None
:
break
if
fst_kind
==
SUB_BIT_KIND
:
sub_idx
,
obj_idx
=
fst_idx
,
nxt
[
0
]
-
OBJ_IDX_OFFSET
else
:
sub_idx
,
obj_idx
=
nxt
[
0
],
fst_idx
-
OBJ_IDX_OFFSET
pair_dis
=
bbox_distance
(
subjects
[
sub_idx
][
"bbox"
],
objects
[
obj_idx
][
"bbox"
])
nearest_dis
=
float
(
"inf"
)
for
i
in
range
(
N
):
# 取消原先算法中 1对1 匹配的偏置
# if i in seen_idx or i == sub_idx:
# continue
nearest_dis
=
min
(
nearest_dis
,
bbox_distance
(
subjects
[
i
][
"bbox"
],
objects
[
obj_idx
][
"bbox"
]))
if
pair_dis
>=
3
*
nearest_dis
:
seen_idx
.
add
(
sub_idx
)
continue
seen_idx
.
add
(
sub_idx
)
seen_idx
.
add
(
obj_idx
+
OBJ_IDX_OFFSET
)
seen_sub_idx
.
add
(
sub_idx
)
ret
.
append
(
{
"sub_bbox"
:
{
"bbox"
:
subjects
[
sub_idx
][
"bbox"
],
"lines"
:
subjects
[
sub_idx
][
"lines"
],
"index"
:
subjects
[
sub_idx
][
"index"
],
},
"obj_bboxes"
:
[
{
"bbox"
:
objects
[
obj_idx
][
"bbox"
],
"lines"
:
objects
[
obj_idx
][
"lines"
],
"index"
:
objects
[
obj_idx
][
"index"
]}
],
"sub_idx"
:
sub_idx
,
}
)
for
i
in
range
(
len
(
objects
)):
j
=
i
+
OBJ_IDX_OFFSET
if
j
in
seen_idx
:
continue
seen_idx
.
add
(
j
)
nearest_dis
,
nearest_sub_idx
=
float
(
"inf"
),
-
1
for
k
in
range
(
len
(
subjects
)):
dis
=
bbox_distance
(
objects
[
i
][
"bbox"
],
subjects
[
k
][
"bbox"
])
if
dis
<
nearest_dis
:
nearest_dis
=
dis
nearest_sub_idx
=
k
for
k
in
range
(
len
(
subjects
)):
if
k
!=
nearest_sub_idx
:
continue
if
k
in
seen_sub_idx
:
for
kk
in
range
(
len
(
ret
)):
if
ret
[
kk
][
"sub_idx"
]
==
k
:
ret
[
kk
][
"obj_bboxes"
].
append
(
{
"bbox"
:
objects
[
i
][
"bbox"
],
"lines"
:
objects
[
i
][
"lines"
],
"index"
:
objects
[
i
][
"index"
]}
)
break
else
:
ret
.
append
(
{
"sub_bbox"
:
{
"bbox"
:
subjects
[
k
][
"bbox"
],
"lines"
:
subjects
[
k
][
"lines"
],
"index"
:
subjects
[
k
][
"index"
],
},
"obj_bboxes"
:
[
{
"bbox"
:
objects
[
i
][
"bbox"
],
"lines"
:
objects
[
i
][
"lines"
],
"index"
:
objects
[
i
][
"index"
]}
],
"sub_idx"
:
k
,
}
)
seen_sub_idx
.
add
(
k
)
seen_idx
.
add
(
k
)
for
i
in
range
(
len
(
subjects
)):
if
i
in
seen_sub_idx
:
continue
ret
.
append
(
{
"sub_bbox"
:
{
"bbox"
:
subjects
[
i
][
"bbox"
],
"lines"
:
subjects
[
i
][
"lines"
],
"index"
:
subjects
[
i
][
"index"
],
},
"obj_bboxes"
:
[],
"sub_idx"
:
i
,
}
)
return
ret
def
get_type_blocks
(
blocks
,
block_type
:
Literal
[
"image"
,
"table"
]):
def
get_type_blocks
(
blocks
,
block_type
:
Literal
[
"image"
,
"table"
]):
...
...
mineru/utils/magic_model_utils.py
0 → 100644
View file @
9bf8be98
"""
布局处理的公共工具类
包含两个MagicModel类中重复使用的方法和逻辑
"""
from
typing
import
List
,
Dict
,
Any
,
Union
from
mineru.utils.boxbase
import
bbox_relative_pos
,
calculate_iou
,
bbox_distance
,
is_in
def reduct_overlap(bboxes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop every entry whose bbox lies inside the bbox of another entry.

    Args:
        bboxes: Dicts that each carry a ``'bbox'`` key.

    Returns:
        A new list with the surviving entries, in their original order.
        The input list is not modified.
    """
    # NOTE(review): if is_in() reports two identical boxes as being inside
    # each other, both copies are dropped -- confirm that is intended.
    return [
        candidate
        for idx, candidate in enumerate(bboxes)
        # any() stops at the first container found instead of scanning on.
        if not any(
            is_in(candidate['bbox'], other['bbox'])
            for other_idx, other in enumerate(bboxes)
            if other_idx != idx
        )
    ]
def bbox_distance_with_relative_check(bbox1: List[int], bbox2: List[int]) -> float:
    """Return the distance between two bboxes, subject to placement checks.

    Args:
        bbox1: First bbox as ``[x1, y1, x2, y2]``.
        bbox2: Second bbox as ``[x1, y1, x2, y2]``.

    Returns:
        ``bbox_distance(bbox1, bbox2)`` when the checks pass, otherwise
        ``float('inf')``. The checks fail when more than one flag returned
        by ``bbox_relative_pos`` is set, or when ``bbox2`` is more than 30%
        longer than ``bbox1`` along the compared side.
    """
    left, right, bottom, top = bbox_relative_pos(bbox1, bbox2)
    # Only a single relative-position flag is acceptable.
    if sum(1 for flag in (left, right, bottom, top) if flag) > 1:
        return float('inf')

    if left or right:
        # Side-by-side boxes: compare heights (y2 - y1).
        l1 = bbox1[3] - bbox1[1]
        l2 = bbox2[3] - bbox2[1]
    else:
        # Otherwise compare widths (x2 - x1).
        l1 = bbox1[2] - bbox1[0]
        l2 = bbox2[2] - bbox2[0]

    if l2 > l1:
        # A zero-length bbox1 side would make the ratio below divide by
        # zero; any longer bbox2 side is then treated as over the limit.
        if l1 == 0 or (l2 - l1) / l1 > 0.3:
            return float('inf')

    return bbox_distance(bbox1, bbox2)
def tie_up_category_by_distance_v3(
        data_source: Union[List[Dict], Dict],
        subject_category_filter,
        object_category_filter,
        extract_bbox_func=None,
        extract_score_func=None,
        create_item_func=None
) -> List[Dict[str, Any]]:
    """Associate "object" items with "subject" items based on bbox distance.

    Args:
        data_source: Either the list of items itself, or a dict holding
            that list under the ``'layout_dets'`` key.
        subject_category_filter: A predicate selecting subject items, or a
            value compared against each item's ``'category_id'`` and
            ``'type'``.
        object_category_filter: Same as above, for object items.
        extract_bbox_func: Returns the bbox of an item produced by
            ``create_item_func``; defaults to reading the ``'bbox'`` key.
        extract_score_func: Returns the score of a raw item; defaults to
            reading the ``'score'`` key. Only used by the default
            ``create_item_func``.
        create_item_func: Converts a raw item into the item stored in the
            result; defaults to a dict with ``'bbox'`` and ``'score'``.

    Returns:
        One dict per subject left after ``reduct_overlap``, with keys
        ``'sub_bbox'`` (the subject item), ``'obj_bboxes'`` (list of the
        object items tied to it, possibly empty) and ``'sub_idx'`` (the
        subject's index after sorting).
    """
    if extract_bbox_func is None:
        extract_bbox_func = lambda x: x['bbox']
    if extract_score_func is None:
        extract_score_func = lambda x: x['score']
    if create_item_func is None:
        create_item_func = lambda x: {'bbox': extract_bbox_func(x), 'score': extract_score_func(x)}

    # Accept both a bare list and a page dict that wraps the list.
    if isinstance(data_source, dict) and 'layout_dets' in data_source:
        items = data_source['layout_dets']
    else:
        items = data_source

    def _select(category_filter):
        """Return the raw items matching a predicate or a category value."""
        if callable(category_filter):
            return list(filter(category_filter, items))
        return [
            x for x in items
            if x.get('category_id') == category_filter or x.get('type') == category_filter
        ]

    raw_subjects = _select(subject_category_filter)
    raw_objects = _select(object_category_filter)

    # Convert to the output item format and drop contained boxes.
    subjects = reduct_overlap([create_item_func(x) for x in raw_subjects])
    objects = reduct_overlap([create_item_func(x) for x in raw_objects])

    def _origin_key(item):
        """Squared distance of the bbox's first corner from (0, 0)."""
        bbox = extract_bbox_func(item)
        return bbox[0] ** 2 + bbox[1] ** 2

    ret = []
    N = len(subjects)
    subjects.sort(key=_origin_key)
    objects.sort(key=_origin_key)

    # Subjects and objects share one index space in `seen_idx`; object
    # indices are shifted by this offset, so this only stays unambiguous
    # while there are fewer than 10000 subjects.
    OBJ_IDX_OFFSET = 10000
    SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1

    def _entry(idx, kind, item):
        """Build an (index, kind, x0, y0) tuple for one item."""
        bbox = extract_bbox_func(item)
        return idx, kind, bbox[0], bbox[1]

    def _bbox_of(idx, kind):
        """Resolve a shared-space index back to the item's bbox."""
        if kind == SUB_BIT_KIND:
            return extract_bbox_func(subjects[idx])
        return extract_bbox_func(objects[idx - OBJ_IDX_OFFSET])

    all_boxes_with_idx = (
        [_entry(i, SUB_BIT_KIND, sub) for i, sub in enumerate(subjects)]
        + [_entry(i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj) for i, obj in enumerate(objects)]
    )
    seen_idx = set()
    seen_sub_idx = set()

    while N > len(seen_sub_idx):
        candidates = [entry for entry in all_boxes_with_idx if entry[0] not in seen_idx]
        if not candidates:
            break

        # Start from the unseen box closest to the top-left corner of the
        # region spanned by all unseen boxes.
        left_x = min(v[2] for v in candidates)
        top_y = min(v[3] for v in candidates)
        candidates.sort(key=lambda c: (c[2] - left_x) ** 2 + (c[3] - top_y) ** 2)
        fst_idx, fst_kind, _, _ = candidates[0]
        fst_bbox = _bbox_of(fst_idx, fst_kind)

        # Re-order by distance to that first box and take the nearest box
        # of the opposite kind as its partner.
        candidates.sort(key=lambda c: bbox_distance(fst_bbox, _bbox_of(c[0], c[1])))
        nxt = next((c for c in candidates[1:] if c[1] ^ fst_kind == 1), None)
        if nxt is None:
            break

        if fst_kind == SUB_BIT_KIND:
            sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
        else:
            sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET

        obj_bbox = extract_bbox_func(objects[obj_idx])
        pair_dis = bbox_distance(extract_bbox_func(subjects[sub_idx]), obj_bbox)

        # Every subject is considered here, including already matched ones;
        # this removes the one-to-one matching bias of the earlier algorithm.
        nearest_dis = float('inf')
        for subject in subjects:
            nearest_dis = min(nearest_dis, bbox_distance(extract_bbox_func(subject), obj_bbox))

        # The subject leaves the candidate pool whether or not it is paired.
        seen_idx.add(sub_idx)
        if pair_dis >= 3 * nearest_dis:
            # Another subject is far closer to this object; do not pair.
            continue

        seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
        seen_sub_idx.add(sub_idx)
        ret.append({
            'sub_bbox': subjects[sub_idx],
            'obj_bboxes': [objects[obj_idx]],
            'sub_idx': sub_idx,
        })

    # Attach each remaining object to its nearest subject.
    for i, obj in enumerate(objects):
        j = i + OBJ_IDX_OFFSET
        if j in seen_idx:
            continue
        seen_idx.add(j)

        nearest_dis, nearest_sub_idx = float('inf'), -1
        for k, subject in enumerate(subjects):
            dis = bbox_distance(extract_bbox_func(obj), extract_bbox_func(subject))
            if dis < nearest_dis:
                nearest_dis, nearest_sub_idx = dis, k

        if nearest_sub_idx == -1:
            # No subject at a finite distance: the object stays unattached.
            continue

        k = nearest_sub_idx
        if k in seen_sub_idx:
            for entry in ret:
                if entry['sub_idx'] == k:
                    entry['obj_bboxes'].append(obj)
                    break
        else:
            ret.append({
                'sub_bbox': subjects[k],
                'obj_bboxes': [obj],
                'sub_idx': k,
            })
        seen_sub_idx.add(k)
        seen_idx.add(k)

    # Emit the remaining subjects with no objects.
    for i, subject in enumerate(subjects):
        if i in seen_sub_idx:
            continue
        ret.append({
            'sub_bbox': subject,
            'obj_bboxes': [],
            'sub_idx': i,
        })

    return ret
def remove_high_iou_low_confidence(layout_dets: List[Dict], iou_threshold: float = 0.9):
    """Remove the lower-scored detection of every heavily overlapping pair.

    The list is modified in place and nothing is returned, so the caller
    must pass the very list it wants pruned; passing a filtered copy leaves
    the source list untouched.

    Args:
        layout_dets: Detections, each with ``'bbox'`` and ``'score'`` keys.
        iou_threshold: Pairs whose IoU is strictly greater than this value
            count as duplicates.
    """
    need_remove_list = []
    for i, layout_det1 in enumerate(layout_dets):
        for layout_det2 in layout_dets[i + 1:]:
            if calculate_iou(layout_det1['bbox'], layout_det2['bbox']) > iou_threshold:
                # On equal scores the later detection is the one dropped.
                weaker = layout_det1 if layout_det1['score'] < layout_det2['score'] else layout_det2
                if weaker not in need_remove_list:
                    need_remove_list.append(weaker)

    # Removal is deferred so the pair scan above sees a stable list.
    for need_remove in need_remove_list:
        if need_remove in layout_dets:
            layout_dets.remove(need_remove)
def remove_low_confidence(layout_dets: List[Dict], confidence_threshold: float = 0.05):
    """Remove detections whose score is at or below the threshold.

    The list is modified in place and nothing is returned.

    Args:
        layout_dets: Detections, each with a ``'score'`` key.
        confidence_threshold: Detections with
            ``score <= confidence_threshold`` are removed.
    """
    # Slice assignment keeps the same list object for callers holding a
    # reference to it, and filters in one pass instead of one
    # `in` + `remove` scan per dropped element.
    layout_dets[:] = [
        layout_det
        for layout_det in layout_dets
        if not layout_det['score'] <= confidence_threshold
    ]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment