Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
1906643c
Commit
1906643c
authored
Jul 23, 2025
by
myhloli
Browse files
refactor: streamline bbox processing and enhance category tying logic in magic_model_utils.py
parent
ee6d557f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
173 additions
and
186 deletions
+173
-186
mineru/backend/pipeline/pipeline_magic_model.py
mineru/backend/pipeline/pipeline_magic_model.py
+78
-18
mineru/backend/vlm/vlm_magic_model.py
mineru/backend/vlm/vlm_magic_model.py
+34
-13
mineru/utils/magic_model_utils.py
mineru/utils/magic_model_utils.py
+61
-155
No files found.
mineru/backend/pipeline/pipeline_magic_model.py
View file @
1906643c
from
mineru.utils.boxbase
import
bbox_relative_pos
,
get_minbox_if_overlap_by_ratio
from
mineru.utils.boxbase
import
bbox_relative_pos
,
calculate_iou
,
bbox_distance
,
is_in
,
get_minbox_if_overlap_by_ratio
from
mineru.utils.enum_class
import
CategoryId
,
ContentType
from
mineru.utils.enum_class
import
CategoryId
,
ContentType
import
mineru.utils.magic_model_utils
as
magic_model_utils
from
mineru.utils.magic_model_utils
import
tie_up_category_by_distance_v3
,
reduct_overlap
class
MagicModel
:
class
MagicModel
:
...
@@ -90,9 +90,18 @@ class MagicModel:
...
@@ -90,9 +90,18 @@ class MagicModel:
layout_dets
.
remove
(
need_remove
)
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_low_confidence
(
self
):
def
__fix_by_remove_low_confidence
(
self
):
magic_model_utils
.
remove_low_confidence
(
self
.
__page_model_info
[
'layout_dets'
])
need_remove_list
=
[]
layout_dets
=
self
.
__page_model_info
[
'layout_dets'
]
for
layout_det
in
layout_dets
:
if
layout_det
[
'score'
]
<=
0.05
:
need_remove_list
.
append
(
layout_det
)
else
:
continue
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_high_iou_and_low_confidence
(
self
):
def
__fix_by_remove_high_iou_and_low_confidence
(
self
):
need_remove_list
=
[]
layout_dets
=
list
(
filter
(
layout_dets
=
list
(
filter
(
lambda
x
:
x
[
'category_id'
]
in
[
lambda
x
:
x
[
'category_id'
]
in
[
CategoryId
.
Title
,
CategoryId
.
Title
,
...
@@ -107,7 +116,20 @@ class MagicModel:
...
@@ -107,7 +116,20 @@ class MagicModel:
],
self
.
__page_model_info
[
'layout_dets'
]
],
self
.
__page_model_info
[
'layout_dets'
]
)
)
)
)
magic_model_utils
.
remove_high_iou_low_confidence
(
layout_dets
)
for
i
in
range
(
len
(
layout_dets
)):
for
j
in
range
(
i
+
1
,
len
(
layout_dets
)):
layout_det1
=
layout_dets
[
i
]
layout_det2
=
layout_dets
[
j
]
if
calculate_iou
(
layout_det1
[
'bbox'
],
layout_det2
[
'bbox'
])
>
0.9
:
layout_det_need_remove
=
layout_det1
if
layout_det1
[
'score'
]
<
layout_det2
[
'score'
]
else
layout_det2
if
layout_det_need_remove
not
in
need_remove_list
:
need_remove_list
.
append
(
layout_det_need_remove
)
for
need_remove
in
need_remove_list
:
self
.
__page_model_info
[
'layout_dets'
].
remove
(
need_remove
)
def
__fix_footnote
(
self
):
def
__fix_footnote
(
self
):
footnotes
=
[]
footnotes
=
[]
...
@@ -141,7 +163,7 @@ class MagicModel:
...
@@ -141,7 +163,7 @@ class MagicModel:
if
pos_flag_count
>
1
:
if
pos_flag_count
>
1
:
continue
continue
dis_figure_footnote
[
i
]
=
min
(
dis_figure_footnote
[
i
]
=
min
(
magic_model_utils
.
bbox_distance_with_relative_check
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
self
.
_bbox_distance
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
)
)
for
i
in
range
(
len
(
footnotes
)):
for
i
in
range
(
len
(
footnotes
)):
...
@@ -160,7 +182,7 @@ class MagicModel:
...
@@ -160,7 +182,7 @@ class MagicModel:
continue
continue
dis_table_footnote
[
i
]
=
min
(
dis_table_footnote
[
i
]
=
min
(
magic_model_utils
.
bbox_distance_with_relative_check
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
self
.
_bbox_distance
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
)
)
for
i
in
range
(
len
(
footnotes
)):
for
i
in
range
(
len
(
footnotes
)):
...
@@ -169,18 +191,56 @@ class MagicModel:
...
@@ -169,18 +191,56 @@ class MagicModel:
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
def
__tie_up_category_by_distance_v3
(
def
_bbox_distance
(
self
,
bbox1
,
bbox2
):
self
,
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
subject_category_id
:
int
,
flags
=
[
left
,
right
,
bottom
,
top
]
object_category_id
:
int
,
count
=
sum
([
1
if
v
else
0
for
v
in
flags
])
):
if
count
>
1
:
return
magic_model_utils
.
tie_up_category_by_distance_v3
(
return
float
(
'inf'
)
self
.
__page_model_info
,
if
left
or
right
:
subject_category_id
,
l1
=
bbox1
[
3
]
-
bbox1
[
1
]
object_category_id
,
l2
=
bbox2
[
3
]
-
bbox2
[
1
]
extract_bbox_func
=
lambda
x
:
x
[
'bbox'
],
else
:
extract_score_func
=
lambda
x
:
x
[
'score'
],
l1
=
bbox1
[
2
]
-
bbox1
[
0
]
create_item_func
=
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]}
l2
=
bbox2
[
2
]
-
bbox2
[
0
]
if
l2
>
l1
and
(
l2
-
l1
)
/
l1
>
0.3
:
return
float
(
'inf'
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
__tie_up_category_by_distance_v3
(
self
,
subject_category_id
,
object_category_id
):
# 定义获取主体和客体对象的函数
def
get_subjects
():
return
reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
self
.
__page_model_info
[
'layout_dets'
],
),
)
)
)
def
get_objects
():
return
reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
self
.
__page_model_info
[
'layout_dets'
],
),
)
)
)
# 调用通用方法
return
tie_up_category_by_distance_v3
(
get_subjects
,
get_objects
)
)
def
get_imgs
(
self
):
def
get_imgs
(
self
):
...
...
mineru/backend/vlm/vlm_magic_model.py
View file @
1906643c
...
@@ -6,7 +6,8 @@ from loguru import logger
...
@@ -6,7 +6,8 @@ from loguru import logger
from
mineru.utils.enum_class
import
ContentType
,
BlockType
,
SplitFlag
from
mineru.utils.enum_class
import
ContentType
,
BlockType
,
SplitFlag
from
mineru.backend.vlm.vlm_middle_json_mkcontent
import
merge_para_with_text
from
mineru.backend.vlm.vlm_middle_json_mkcontent
import
merge_para_with_text
from
mineru.utils.format_utils
import
convert_otsl_to_html
from
mineru.utils.format_utils
import
convert_otsl_to_html
import
mineru.utils.magic_model_utils
as
magic_model_utils
from
mineru.utils.magic_model_utils
import
reduct_overlap
,
tie_up_category_by_distance_v3
class
MagicModel
:
class
MagicModel
:
def
__init__
(
self
,
token
:
str
,
width
,
height
):
def
__init__
(
self
,
token
:
str
,
width
,
height
):
...
@@ -250,18 +251,38 @@ def latex_fix(latex):
...
@@ -250,18 +251,38 @@ def latex_fix(latex):
return
latex
return
latex
def
__tie_up_category_by_distance_v3
(
def
__tie_up_category_by_distance_v3
(
blocks
,
subject_block_type
,
object_block_type
):
blocks
:
list
,
# 定义获取主体和客体对象的函数
subject_block_type
:
str
,
def
get_subjects
():
object_block_type
:
str
,
return
reduct_overlap
(
):
list
(
return
magic_model_utils
.
tie_up_category_by_distance_v3
(
map
(
blocks
,
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
filter
(
lambda
x
:
x
[
"type"
]
==
subject_block_type
,
lambda
x
:
x
[
"type"
]
==
subject_block_type
,
blocks
,
),
)
)
)
def
get_objects
():
return
reduct_overlap
(
list
(
map
(
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
filter
(
lambda
x
:
x
[
"type"
]
==
object_block_type
,
lambda
x
:
x
[
"type"
]
==
object_block_type
,
extract_bbox_func
=
lambda
x
:
x
[
"bbox"
],
blocks
,
extract_score_func
=
lambda
x
:
x
.
get
(
"score"
,
1.0
),
),
create_item_func
=
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]}
)
)
)
# 调用通用方法
return
tie_up_category_by_distance_v3
(
get_subjects
,
get_objects
)
)
...
...
mineru/utils/magic_model_utils.py
View file @
1906643c
"""
"""
布局处理的公共工具类
包含两个MagicModel类中重复使用的方法和逻辑
包含两个MagicModel类中重复使用的方法和逻辑
"""
"""
from
typing
import
List
,
Dict
,
Any
,
Union
from
typing
import
List
,
Dict
,
Any
,
Callable
from
mineru.utils.boxbase
import
bbox_relative_pos
,
calculate_iou
,
bbox_distance
,
is_in
from
mineru.utils.boxbase
import
bbox_distance
,
is_in
def
reduct_overlap
(
bboxes
:
List
[
Dict
[
str
,
Any
]])
->
List
[
Dict
[
str
,
Any
]]:
def
reduct_overlap
(
bboxes
:
List
[
Dict
[
str
,
Any
]])
->
List
[
Dict
[
str
,
Any
]]:
...
@@ -27,101 +26,48 @@ def reduct_overlap(bboxes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
...
@@ -27,101 +26,48 @@ def reduct_overlap(bboxes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
def
bbox_distance_with_relative_check
(
bbox1
:
List
[
int
],
bbox2
:
List
[
int
])
->
float
:
"""
计算两个bbox之间的距离,考虑相对位置约束
Args:
bbox1: 第一个bbox [x1, y1, x2, y2]
bbox2: 第二个bbox [x1, y1, x2, y2]
Returns:
距离值,如果不满足条件返回无穷大
"""
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
flags
=
[
left
,
right
,
bottom
,
top
]
count
=
sum
([
1
if
v
else
0
for
v
in
flags
])
if
count
>
1
:
return
float
(
'inf'
)
if
left
or
right
:
l1
=
bbox1
[
3
]
-
bbox1
[
1
]
l2
=
bbox2
[
3
]
-
bbox2
[
1
]
else
:
l1
=
bbox1
[
2
]
-
bbox1
[
0
]
l2
=
bbox2
[
2
]
-
bbox2
[
0
]
if
l2
>
l1
and
(
l2
-
l1
)
/
l1
>
0.3
:
return
float
(
'inf'
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
tie_up_category_by_distance_v3
(
def
tie_up_category_by_distance_v3
(
data_source
:
Union
[
List
[
Dict
],
Dict
],
get_subjects_func
:
Callable
,
subject_category_filter
,
get_objects_func
:
Callable
,
object_category_filter
,
extract_subject_func
:
Callable
=
None
,
extract_bbox_func
=
None
,
extract_object_func
:
Callable
=
None
extract_score_func
=
None
,
):
create_item_func
=
None
)
->
List
[
Dict
[
str
,
Any
]]:
"""
"""
基于距离关联不同类型的区块/元素
通用的类别关联方法,用于将主体对象与客体对象进行关联
Args:
参数:
data_source: 数据源,可以是列表或包含layout_dets的字典
get_subjects_func: 函数,提取主体对象
subject_category_filter: 主体类别过滤函数或值
get_objects_func: 函数,提取客体对象
object_category_filter: 对象类别过滤函数或值
extract_subject_func: 函数,自定义提取主体属性(默认使用bbox和其他属性)
extract_bbox_func: 提取bbox的函数,默认使用'bbox'键
extract_object_func: 函数,自定义提取客体属性(默认使用bbox和其他属性)
extract_score_func: 提取score的函数,默认使用'score'键
create_item_func: 创建返回项的函数
Returns
:
返回
:
关联
结果
列表
关联
后的对象
列表
"""
"""
# 默认函数
subjects
=
get_subjects_func
()
if
extract_bbox_func
is
None
:
objects
=
get_objects_func
()
extract_bbox_func
=
lambda
x
:
x
[
'bbox'
]
if
extract_score_func
is
None
:
extract_score_func
=
lambda
x
:
x
[
'score'
]
if
create_item_func
is
None
:
create_item_func
=
lambda
x
:
{
'bbox'
:
extract_bbox_func
(
x
),
'score'
:
extract_score_func
(
x
)}
# 处理数据源
if
isinstance
(
data_source
,
dict
)
and
'layout_dets'
in
data_source
:
items
=
data_source
[
'layout_dets'
]
else
:
items
=
data_source
# 过滤主体和对象
if
callable
(
subject_category_filter
):
subjects
=
list
(
filter
(
subject_category_filter
,
items
))
else
:
subjects
=
list
(
filter
(
lambda
x
:
x
.
get
(
'category_id'
)
==
subject_category_filter
or
x
.
get
(
'type'
)
==
subject_category_filter
,
items
))
if
callable
(
object_category_filter
):
objects
=
list
(
filter
(
object_category_filter
,
items
))
else
:
objects
=
list
(
filter
(
lambda
x
:
x
.
get
(
'category_id'
)
==
object_category_filter
or
x
.
get
(
'type'
)
==
object_category_filter
,
items
))
# 转换为标准格式并去重
# 如果没有提供自定义提取函数,使用默认函数
subjects
=
reduct_overlap
([
create_item_func
(
x
)
for
x
in
subjects
])
if
extract_subject_func
is
None
:
objects
=
reduct_overlap
([
create_item_func
(
x
)
for
x
in
objects
])
extract_subject_func
=
lambda
x
:
x
if
extract_object_func
is
None
:
extract_object_func
=
lambda
x
:
x
ret
=
[]
ret
=
[]
N
,
M
=
len
(
subjects
),
len
(
objects
)
N
,
M
=
len
(
subjects
),
len
(
objects
)
subjects
.
sort
(
key
=
lambda
x
:
extract_bbox_func
(
x
)[
0
]
**
2
+
extract_bbox_func
(
x
)
[
1
]
**
2
)
subjects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
]
[
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
extract_bbox_func
(
x
)[
0
]
**
2
+
extract_bbox_func
(
x
)
[
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
]
[
1
]
**
2
)
OBJ_IDX_OFFSET
=
10000
OBJ_IDX_OFFSET
=
10000
SUB_BIT_KIND
,
OBJ_BIT_KIND
=
0
,
1
SUB_BIT_KIND
,
OBJ_BIT_KIND
=
0
,
1
all_boxes_with_idx
=
[(
i
,
SUB_BIT_KIND
,
extract_bbox_func
(
sub
)[
0
],
extract_bbox_func
(
sub
)[
1
])
for
i
,
sub
in
enumerate
(
subjects
)]
+
\
all_boxes_with_idx
=
[(
i
,
SUB_BIT_KIND
,
sub
[
"bbox"
][
0
],
sub
[
"bbox"
][
1
])
for
i
,
sub
in
enumerate
(
subjects
)]
+
[
[(
i
+
OBJ_IDX_OFFSET
,
OBJ_BIT_KIND
,
extract_bbox_func
(
obj
)[
0
],
extract_bbox_func
(
obj
)[
1
])
for
i
,
obj
in
enumerate
(
objects
)]
(
i
+
OBJ_IDX_OFFSET
,
OBJ_BIT_KIND
,
obj
[
"bbox"
][
0
],
obj
[
"bbox"
][
1
])
for
i
,
obj
in
enumerate
(
objects
)
]
seen_idx
=
set
()
seen_idx
=
set
()
seen_sub_idx
=
set
()
seen_sub_idx
=
set
()
seen_sub_idx_len
=
len
(
seen_sub_idx
)
while
N
>
len
(
seen_sub_idx
):
while
N
>
seen_sub_idx_len
:
candidates
=
[]
candidates
=
[]
for
idx
,
kind
,
x0
,
y0
in
all_boxes_with_idx
:
for
idx
,
kind
,
x0
,
y0
in
all_boxes_with_idx
:
if
idx
in
seen_idx
:
if
idx
in
seen_idx
:
...
@@ -136,10 +82,10 @@ def tie_up_category_by_distance_v3(
...
@@ -136,10 +82,10 @@ def tie_up_category_by_distance_v3(
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
fst_idx
,
fst_kind
,
left_x
,
top_y
=
candidates
[
0
]
fst_idx
,
fst_kind
,
left_x
,
top_y
=
candidates
[
0
]
fst_bbox
=
extract_bbox_func
(
subjects
[
fst_idx
]
)
if
fst_kind
==
SUB_BIT_KIND
else
extract_bbox_func
(
objects
[
fst_idx
-
OBJ_IDX_OFFSET
]
)
fst_bbox
=
subjects
[
fst_idx
]
[
'bbox'
]
if
fst_kind
==
SUB_BIT_KIND
else
objects
[
fst_idx
-
OBJ_IDX_OFFSET
]
[
'bbox'
]
candidates
.
sort
(
candidates
.
sort
(
key
=
lambda
x
:
bbox_distance
(
fst_bbox
,
extract_bbox_func
(
subjects
[
x
[
0
]]
)
)
if
x
[
1
]
==
SUB_BIT_KIND
else
bbox_distance
(
key
=
lambda
x
:
bbox_distance
(
fst_bbox
,
subjects
[
x
[
0
]]
[
'bbox'
]
)
if
x
[
1
]
==
SUB_BIT_KIND
else
bbox_distance
(
fst_bbox
,
extract_bbox_func
(
objects
[
x
[
0
]
-
OBJ_IDX_OFFSET
]
)
))
fst_bbox
,
objects
[
x
[
0
]
-
OBJ_IDX_OFFSET
]
[
'bbox'
]
))
nxt
=
None
nxt
=
None
for
i
in
range
(
1
,
len
(
candidates
)):
for
i
in
range
(
1
,
len
(
candidates
)):
...
@@ -154,12 +100,12 @@ def tie_up_category_by_distance_v3(
...
@@ -154,12 +100,12 @@ def tie_up_category_by_distance_v3(
else
:
else
:
sub_idx
,
obj_idx
=
nxt
[
0
],
fst_idx
-
OBJ_IDX_OFFSET
sub_idx
,
obj_idx
=
nxt
[
0
],
fst_idx
-
OBJ_IDX_OFFSET
pair_dis
=
bbox_distance
(
extract_bbox_func
(
subjects
[
sub_idx
]),
extract_bbox_func
(
objects
[
obj_idx
]
)
)
pair_dis
=
bbox_distance
(
subjects
[
sub_idx
][
"bbox"
],
objects
[
obj_idx
]
[
"bbox"
]
)
nearest_dis
=
float
(
'
inf
'
)
nearest_dis
=
float
(
"
inf
"
)
for
i
in
range
(
N
):
for
i
in
range
(
N
):
# 取消原先算法中 1对1 匹配的偏置
# 取消原先算法中 1对1 匹配的偏置
# if i in seen_idx or i == sub_idx:continue
# if i in seen_idx or i == sub_idx:continue
nearest_dis
=
min
(
nearest_dis
,
bbox_distance
(
extract_bbox_func
(
subjects
[
i
]),
extract_bbox_func
(
objects
[
obj_idx
]
)
))
nearest_dis
=
min
(
nearest_dis
,
bbox_distance
(
subjects
[
i
][
"bbox"
],
objects
[
obj_idx
]
[
"bbox"
]
))
if
pair_dis
>=
3
*
nearest_dis
:
if
pair_dis
>=
3
*
nearest_dis
:
seen_idx
.
add
(
sub_idx
)
seen_idx
.
add
(
sub_idx
)
...
@@ -169,21 +115,22 @@ def tie_up_category_by_distance_v3(
...
@@ -169,21 +115,22 @@ def tie_up_category_by_distance_v3(
seen_idx
.
add
(
obj_idx
+
OBJ_IDX_OFFSET
)
seen_idx
.
add
(
obj_idx
+
OBJ_IDX_OFFSET
)
seen_sub_idx
.
add
(
sub_idx
)
seen_sub_idx
.
add
(
sub_idx
)
ret
.
append
({
ret
.
append
(
'sub_bbox'
:
subjects
[
sub_idx
],
{
'obj_bboxes'
:
[
objects
[
obj_idx
]],
"sub_bbox"
:
extract_subject_func
(
subjects
[
sub_idx
]),
'sub_idx'
:
sub_idx
,
"obj_bboxes"
:
[
extract_object_func
(
objects
[
obj_idx
])],
})
"sub_idx"
:
sub_idx
,
}
)
# 处理剩余的对象
for
i
in
range
(
len
(
objects
)):
for
i
in
range
(
len
(
objects
)):
j
=
i
+
OBJ_IDX_OFFSET
j
=
i
+
OBJ_IDX_OFFSET
if
j
in
seen_idx
:
if
j
in
seen_idx
:
continue
continue
seen_idx
.
add
(
j
)
seen_idx
.
add
(
j
)
nearest_dis
,
nearest_sub_idx
=
float
(
'
inf
'
),
-
1
nearest_dis
,
nearest_sub_idx
=
float
(
"
inf
"
),
-
1
for
k
in
range
(
len
(
subjects
)):
for
k
in
range
(
len
(
subjects
)):
dis
=
bbox_distance
(
extract_bbox_func
(
objects
[
i
]),
extract_bbox_func
(
subjects
[
k
]
)
)
dis
=
bbox_distance
(
objects
[
i
][
"bbox"
],
subjects
[
k
]
[
"bbox"
]
)
if
dis
<
nearest_dis
:
if
dis
<
nearest_dis
:
nearest_dis
=
dis
nearest_dis
=
dis
nearest_sub_idx
=
k
nearest_sub_idx
=
k
...
@@ -193,70 +140,29 @@ def tie_up_category_by_distance_v3(
...
@@ -193,70 +140,29 @@ def tie_up_category_by_distance_v3(
continue
continue
if
k
in
seen_sub_idx
:
if
k
in
seen_sub_idx
:
for
kk
in
range
(
len
(
ret
)):
for
kk
in
range
(
len
(
ret
)):
if
ret
[
kk
][
'
sub_idx
'
]
==
k
:
if
ret
[
kk
][
"
sub_idx
"
]
==
k
:
ret
[
kk
][
'
obj_bboxes
'
].
append
(
objects
[
i
])
ret
[
kk
][
"
obj_bboxes
"
].
append
(
extract_object_func
(
objects
[
i
])
)
break
break
else
:
else
:
ret
.
append
({
ret
.
append
(
'sub_bbox'
:
subjects
[
k
],
{
'obj_bboxes'
:
[
objects
[
i
]],
"sub_bbox"
:
extract_subject_func
(
subjects
[
k
]),
'sub_idx'
:
k
,
"obj_bboxes"
:
[
extract_object_func
(
objects
[
i
])],
})
"sub_idx"
:
k
,
}
)
seen_sub_idx
.
add
(
k
)
seen_sub_idx
.
add
(
k
)
seen_idx
.
add
(
k
)
seen_idx
.
add
(
k
)
# 处理剩余的主体
for
i
in
range
(
len
(
subjects
)):
for
i
in
range
(
len
(
subjects
)):
if
i
in
seen_sub_idx
:
if
i
in
seen_sub_idx
:
continue
continue
ret
.
append
({
ret
.
append
(
'sub_bbox'
:
subjects
[
i
],
{
'obj_bboxes'
:
[],
"sub_bbox"
:
extract_subject_func
(
subjects
[
i
]),
'sub_idx'
:
i
,
"obj_bboxes"
:
[],
})
"sub_idx"
:
i
,
}
)
return
ret
return
ret
\ No newline at end of file
def
remove_high_iou_low_confidence
(
layout_dets
:
List
[
Dict
],
iou_threshold
:
float
=
0.9
):
"""
删除高IOU且置信度较低的检测结果
Args:
layout_dets: 布局检测结果列表
iou_threshold: IOU阈值
"""
need_remove_list
=
[]
for
i
in
range
(
len
(
layout_dets
)):
for
j
in
range
(
i
+
1
,
len
(
layout_dets
)):
layout_det1
=
layout_dets
[
i
]
layout_det2
=
layout_dets
[
j
]
if
calculate_iou
(
layout_det1
[
'bbox'
],
layout_det2
[
'bbox'
])
>
iou_threshold
:
layout_det_need_remove
=
layout_det1
if
layout_det1
[
'score'
]
<
layout_det2
[
'score'
]
else
layout_det2
if
layout_det_need_remove
not
in
need_remove_list
:
need_remove_list
.
append
(
layout_det_need_remove
)
for
need_remove
in
need_remove_list
:
if
need_remove
in
layout_dets
:
layout_dets
.
remove
(
need_remove
)
def
remove_low_confidence
(
layout_dets
:
List
[
Dict
],
confidence_threshold
:
float
=
0.05
):
"""
删除置信度特别低的检测结果
Args:
layout_dets: 布局检测结果列表
confidence_threshold: 置信度阈值
"""
need_remove_list
=
[]
for
layout_det
in
layout_dets
:
if
layout_det
[
'score'
]
<=
confidence_threshold
:
need_remove_list
.
append
(
layout_det
)
for
need_remove
in
need_remove_list
:
if
need_remove
in
layout_dets
:
layout_dets
.
remove
(
need_remove
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment