Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
a5583ff4
Unverified
Commit
a5583ff4
authored
Jul 23, 2025
by
Xiaomeng Zhao
Committed by
GitHub
Jul 23, 2025
Browse files
fix: improve candidate sorting logic in vlm_magic_model.py
fix: improve candidate sorting logic in vlm_magic_model.py
parents
beacccb6
1906643c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
226 additions
and
324 deletions
+226
-324
mineru/backend/pipeline/pipeline_magic_model.py
mineru/backend/pipeline/pipeline_magic_model.py
+28
-158
mineru/backend/vlm/vlm_magic_model.py
mineru/backend/vlm/vlm_magic_model.py
+28
-164
mineru/model/mfr/unimernet/Unimernet.py
mineru/model/mfr/unimernet/Unimernet.py
+2
-2
mineru/utils/magic_model_utils.py
mineru/utils/magic_model_utils.py
+168
-0
No files found.
mineru/backend/pipeline/pipeline_magic_model.py
View file @
a5583ff4
from
mineru.utils.boxbase
import
bbox_relative_pos
,
calculate_iou
,
bbox_distance
,
is_in
,
get_minbox_if_overlap_by_ratio
from
mineru.utils.boxbase
import
bbox_relative_pos
,
calculate_iou
,
bbox_distance
,
is_in
,
get_minbox_if_overlap_by_ratio
from
mineru.utils.enum_class
import
CategoryId
,
ContentType
from
mineru.utils.enum_class
import
CategoryId
,
ContentType
from
mineru.utils.magic_model_utils
import
tie_up_category_by_distance_v3
,
reduct_overlap
class
MagicModel
:
class
MagicModel
:
...
@@ -208,170 +209,39 @@ class MagicModel:
...
@@ -208,170 +209,39 @@ class MagicModel:
return
bbox_distance
(
bbox1
,
bbox2
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
__reduct_overlap
(
self
,
bboxes
):
def
__tie_up_category_by_distance_v3
(
self
,
subject_category_id
,
object_category_id
):
N
=
len
(
bboxes
)
# 定义获取主体和客体对象的函数
keep
=
[
True
]
*
N
def
get_subjects
():
for
i
in
range
(
N
):
return
reduct_overlap
(
for
j
in
range
(
N
):
list
(
if
i
==
j
:
map
(
continue
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
if
is_in
(
bboxes
[
i
][
'bbox'
],
bboxes
[
j
][
'bbox'
]):
filter
(
keep
[
i
]
=
False
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
self
.
__page_model_info
[
'layout_dets'
],
),
def
__tie_up_category_by_distance_v3
(
)
self
,
subject_category_id
:
int
,
object_category_id
:
int
,
):
subjects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
self
.
__page_model_info
[
'layout_dets'
],
),
)
)
)
objects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
self
.
__page_model_info
[
'layout_dets'
],
),
)
)
)
)
)
ret
=
[]
N
,
M
=
len
(
subjects
),
len
(
objects
)
subjects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
OBJ_IDX_OFFSET
=
10000
SUB_BIT_KIND
,
OBJ_BIT_KIND
=
0
,
1
all_boxes_with_idx
=
[(
i
,
SUB_BIT_KIND
,
sub
[
'bbox'
][
0
],
sub
[
'bbox'
][
1
])
for
i
,
sub
in
enumerate
(
subjects
)]
+
[(
i
+
OBJ_IDX_OFFSET
,
OBJ_BIT_KIND
,
obj
[
'bbox'
][
0
],
obj
[
'bbox'
][
1
])
for
i
,
obj
in
enumerate
(
objects
)]
seen_idx
=
set
()
seen_sub_idx
=
set
()
while
N
>
len
(
seen_sub_idx
):
candidates
=
[]
for
idx
,
kind
,
x0
,
y0
in
all_boxes_with_idx
:
if
idx
in
seen_idx
:
continue
candidates
.
append
((
idx
,
kind
,
x0
,
y0
))
if
len
(
candidates
)
==
0
:
break
left_x
=
min
([
v
[
2
]
for
v
in
candidates
])
top_y
=
min
([
v
[
3
]
for
v
in
candidates
])
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
fst_idx
,
fst_kind
,
left_x
,
top_y
=
candidates
[
0
]
fst_bbox
=
subjects
[
fst_idx
][
'bbox'
]
if
fst_kind
==
SUB_BIT_KIND
else
objects
[
fst_idx
-
OBJ_IDX_OFFSET
][
'bbox'
]
candidates
.
sort
(
key
=
lambda
x
:
bbox_distance
(
fst_bbox
,
subjects
[
x
[
0
]][
'bbox'
])
if
x
[
1
]
==
SUB_BIT_KIND
else
bbox_distance
(
fst_bbox
,
objects
[
x
[
0
]
-
OBJ_IDX_OFFSET
][
'bbox'
]))
nxt
=
None
for
i
in
range
(
1
,
len
(
candidates
)):
if
candidates
[
i
][
1
]
^
fst_kind
==
1
:
nxt
=
candidates
[
i
]
break
if
nxt
is
None
:
break
if
fst_kind
==
SUB_BIT_KIND
:
sub_idx
,
obj_idx
=
fst_idx
,
nxt
[
0
]
-
OBJ_IDX_OFFSET
else
:
def
get_objects
():
sub_idx
,
obj_idx
=
nxt
[
0
],
fst_idx
-
OBJ_IDX_OFFSET
return
reduct_overlap
(
list
(
pair_dis
=
bbox_distance
(
subjects
[
sub_idx
][
'bbox'
],
objects
[
obj_idx
][
'bbox'
])
map
(
nearest_dis
=
float
(
'inf'
)
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
for
i
in
range
(
N
):
filter
(
# 取消原先算法中 1对1 匹配的偏置
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
# if i in seen_idx or i == sub_idx:continue
self
.
__page_model_info
[
'layout_dets'
],
nearest_dis
=
min
(
nearest_dis
,
bbox_distance
(
subjects
[
i
][
'bbox'
],
objects
[
obj_idx
][
'bbox'
]))
),
if
pair_dis
>=
3
*
nearest_dis
:
seen_idx
.
add
(
sub_idx
)
continue
seen_idx
.
add
(
sub_idx
)
seen_idx
.
add
(
obj_idx
+
OBJ_IDX_OFFSET
)
seen_sub_idx
.
add
(
sub_idx
)
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
sub_idx
][
'bbox'
],
'score'
:
subjects
[
sub_idx
][
'score'
],
},
'obj_bboxes'
:
[
{
'score'
:
objects
[
obj_idx
][
'score'
],
'bbox'
:
objects
[
obj_idx
][
'bbox'
]}
],
'sub_idx'
:
sub_idx
,
}
)
for
i
in
range
(
len
(
objects
)):
j
=
i
+
OBJ_IDX_OFFSET
if
j
in
seen_idx
:
continue
seen_idx
.
add
(
j
)
nearest_dis
,
nearest_sub_idx
=
float
(
'inf'
),
-
1
for
k
in
range
(
len
(
subjects
)):
dis
=
bbox_distance
(
objects
[
i
][
'bbox'
],
subjects
[
k
][
'bbox'
])
if
dis
<
nearest_dis
:
nearest_dis
=
dis
nearest_sub_idx
=
k
for
k
in
range
(
len
(
subjects
)):
if
k
!=
nearest_sub_idx
:
continue
if
k
in
seen_sub_idx
:
for
kk
in
range
(
len
(
ret
)):
if
ret
[
kk
][
'sub_idx'
]
==
k
:
ret
[
kk
][
'obj_bboxes'
].
append
({
'score'
:
objects
[
i
][
'score'
],
'bbox'
:
objects
[
i
][
'bbox'
]})
break
else
:
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
k
][
'bbox'
],
'score'
:
subjects
[
k
][
'score'
],
},
'obj_bboxes'
:
[
{
'score'
:
objects
[
i
][
'score'
],
'bbox'
:
objects
[
i
][
'bbox'
]}
],
'sub_idx'
:
k
,
}
)
)
seen_sub_idx
.
add
(
k
)
)
seen_idx
.
add
(
k
)
for
i
in
range
(
len
(
subjects
)):
if
i
in
seen_sub_idx
:
continue
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
i
][
'bbox'
],
'score'
:
subjects
[
i
][
'score'
],
},
'obj_bboxes'
:
[],
'sub_idx'
:
i
,
}
)
)
# 调用通用方法
return
ret
return
tie_up_category_by_distance_v3
(
get_subjects
,
get_objects
)
def
get_imgs
(
self
):
def
get_imgs
(
self
):
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
...
...
mineru/backend/vlm/vlm_magic_model.py
View file @
a5583ff4
...
@@ -3,10 +3,10 @@ from typing import Literal
...
@@ -3,10 +3,10 @@ from typing import Literal
from
loguru
import
logger
from
loguru
import
logger
from
mineru.utils.boxbase
import
bbox_distance
,
is_in
from
mineru.utils.enum_class
import
ContentType
,
BlockType
,
SplitFlag
from
mineru.utils.enum_class
import
ContentType
,
BlockType
,
SplitFlag
from
mineru.backend.vlm.vlm_middle_json_mkcontent
import
merge_para_with_text
from
mineru.backend.vlm.vlm_middle_json_mkcontent
import
merge_para_with_text
from
mineru.utils.format_utils
import
convert_otsl_to_html
from
mineru.utils.format_utils
import
convert_otsl_to_html
from
mineru.utils.magic_model_utils
import
reduct_overlap
,
tie_up_category_by_distance_v3
class
MagicModel
:
class
MagicModel
:
...
@@ -251,175 +251,39 @@ def latex_fix(latex):
...
@@ -251,175 +251,39 @@ def latex_fix(latex):
return
latex
return
latex
def
__reduct_overlap
(
bboxes
):
def
__tie_up_category_by_distance_v3
(
blocks
,
subject_block_type
,
object_block_type
):
N
=
len
(
bboxes
)
# 定义获取主体和客体对象的函数
keep
=
[
True
]
*
N
def
get_subjects
():
for
i
in
range
(
N
):
return
reduct_overlap
(
for
j
in
range
(
N
):
list
(
if
i
==
j
:
map
(
continue
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
if
is_in
(
bboxes
[
i
][
"bbox"
],
bboxes
[
j
][
"bbox"
]):
filter
(
keep
[
i
]
=
False
lambda
x
:
x
[
"type"
]
==
subject_block_type
,
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
blocks
,
),
)
def
__tie_up_category_by_distance_v3
(
blocks
:
list
,
subject_block_type
:
str
,
object_block_type
:
str
,
):
subjects
=
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
filter
(
lambda
x
:
x
[
"type"
]
==
subject_block_type
,
blocks
,
),
)
)
)
objects
=
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
filter
(
lambda
x
:
x
[
"type"
]
==
object_block_type
,
blocks
,
),
)
)
)
)
)
ret
=
[]
def
get_objects
():
N
,
M
=
len
(
subjects
),
len
(
objects
)
return
reduct_overlap
(
subjects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
][
1
]
**
2
)
list
(
objects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
][
1
]
**
2
)
map
(
lambda
x
:
{
"bbox"
:
x
[
"bbox"
],
"lines"
:
x
[
"lines"
],
"index"
:
x
[
"index"
]},
OBJ_IDX_OFFSET
=
10000
filter
(
SUB_BIT_KIND
,
OBJ_BIT_KIND
=
0
,
1
lambda
x
:
x
[
"type"
]
==
object_block_type
,
blocks
,
all_boxes_with_idx
=
[(
i
,
SUB_BIT_KIND
,
sub
[
"bbox"
][
0
],
sub
[
"bbox"
][
1
])
for
i
,
sub
in
enumerate
(
subjects
)]
+
[
),
(
i
+
OBJ_IDX_OFFSET
,
OBJ_BIT_KIND
,
obj
[
"bbox"
][
0
],
obj
[
"bbox"
][
1
])
for
i
,
obj
in
enumerate
(
objects
)
]
seen_idx
=
set
()
seen_sub_idx
=
set
()
while
N
>
len
(
seen_sub_idx
):
candidates
=
[]
for
idx
,
kind
,
x0
,
y0
in
all_boxes_with_idx
:
if
idx
in
seen_idx
:
continue
candidates
.
append
((
idx
,
kind
,
x0
,
y0
))
if
len
(
candidates
)
==
0
:
break
left_x
=
min
([
v
[
2
]
for
v
in
candidates
])
top_y
=
min
([
v
[
3
]
for
v
in
candidates
])
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
fst_idx
,
fst_kind
,
left_x
,
top_y
=
candidates
[
0
]
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
nxt
=
None
for
i
in
range
(
1
,
len
(
candidates
)):
if
candidates
[
i
][
1
]
^
fst_kind
==
1
:
nxt
=
candidates
[
i
]
break
if
nxt
is
None
:
break
if
fst_kind
==
SUB_BIT_KIND
:
sub_idx
,
obj_idx
=
fst_idx
,
nxt
[
0
]
-
OBJ_IDX_OFFSET
else
:
sub_idx
,
obj_idx
=
nxt
[
0
],
fst_idx
-
OBJ_IDX_OFFSET
pair_dis
=
bbox_distance
(
subjects
[
sub_idx
][
"bbox"
],
objects
[
obj_idx
][
"bbox"
])
nearest_dis
=
float
(
"inf"
)
for
i
in
range
(
N
):
if
i
in
seen_idx
or
i
==
sub_idx
:
continue
nearest_dis
=
min
(
nearest_dis
,
bbox_distance
(
subjects
[
i
][
"bbox"
],
objects
[
obj_idx
][
"bbox"
]))
if
pair_dis
>=
3
*
nearest_dis
:
seen_idx
.
add
(
sub_idx
)
continue
seen_idx
.
add
(
sub_idx
)
seen_idx
.
add
(
obj_idx
+
OBJ_IDX_OFFSET
)
seen_sub_idx
.
add
(
sub_idx
)
ret
.
append
(
{
"sub_bbox"
:
{
"bbox"
:
subjects
[
sub_idx
][
"bbox"
],
"lines"
:
subjects
[
sub_idx
][
"lines"
],
"index"
:
subjects
[
sub_idx
][
"index"
],
},
"obj_bboxes"
:
[
{
"bbox"
:
objects
[
obj_idx
][
"bbox"
],
"lines"
:
objects
[
obj_idx
][
"lines"
],
"index"
:
objects
[
obj_idx
][
"index"
]}
],
"sub_idx"
:
sub_idx
,
}
)
for
i
in
range
(
len
(
objects
)):
j
=
i
+
OBJ_IDX_OFFSET
if
j
in
seen_idx
:
continue
seen_idx
.
add
(
j
)
nearest_dis
,
nearest_sub_idx
=
float
(
"inf"
),
-
1
for
k
in
range
(
len
(
subjects
)):
dis
=
bbox_distance
(
objects
[
i
][
"bbox"
],
subjects
[
k
][
"bbox"
])
if
dis
<
nearest_dis
:
nearest_dis
=
dis
nearest_sub_idx
=
k
for
k
in
range
(
len
(
subjects
)):
if
k
!=
nearest_sub_idx
:
continue
if
k
in
seen_sub_idx
:
for
kk
in
range
(
len
(
ret
)):
if
ret
[
kk
][
"sub_idx"
]
==
k
:
ret
[
kk
][
"obj_bboxes"
].
append
(
{
"bbox"
:
objects
[
i
][
"bbox"
],
"lines"
:
objects
[
i
][
"lines"
],
"index"
:
objects
[
i
][
"index"
]}
)
break
else
:
ret
.
append
(
{
"sub_bbox"
:
{
"bbox"
:
subjects
[
k
][
"bbox"
],
"lines"
:
subjects
[
k
][
"lines"
],
"index"
:
subjects
[
k
][
"index"
],
},
"obj_bboxes"
:
[
{
"bbox"
:
objects
[
i
][
"bbox"
],
"lines"
:
objects
[
i
][
"lines"
],
"index"
:
objects
[
i
][
"index"
]}
],
"sub_idx"
:
k
,
}
)
)
seen_sub_idx
.
add
(
k
)
)
seen_idx
.
add
(
k
)
for
i
in
range
(
len
(
subjects
)):
if
i
in
seen_sub_idx
:
continue
ret
.
append
(
{
"sub_bbox"
:
{
"bbox"
:
subjects
[
i
][
"bbox"
],
"lines"
:
subjects
[
i
][
"lines"
],
"index"
:
subjects
[
i
][
"index"
],
},
"obj_bboxes"
:
[],
"sub_idx"
:
i
,
}
)
)
return
ret
# 调用通用方法
return
tie_up_category_by_distance_v3
(
get_subjects
,
get_objects
)
def
get_type_blocks
(
blocks
,
block_type
:
Literal
[
"image"
,
"table"
]):
def
get_type_blocks
(
blocks
,
block_type
:
Literal
[
"image"
,
"table"
]):
...
...
mineru/model/mfr/unimernet/Unimernet.py
View file @
a5583ff4
...
@@ -105,8 +105,8 @@ class UnimernetModel(object):
...
@@ -105,8 +105,8 @@ class UnimernetModel(object):
# Create dataset with sorted images
# Create dataset with sorted images
dataset
=
MathDataset
(
sorted_images
,
transform
=
self
.
model
.
transform
)
dataset
=
MathDataset
(
sorted_images
,
transform
=
self
.
model
.
transform
)
# 如果batch_size> len(sorted_images),则设置为不超过len(sorted_images)的2的幂
# 如果batch_size
> len(sorted_images),则设置为不超过len(sorted_images)的2的幂
batch_size
=
min
(
batch_size
,
2
**
(
len
(
sorted_images
).
bit_length
()
-
1
))
batch_size
=
min
(
batch_size
,
max
(
1
,
2
**
(
len
(
sorted_images
).
bit_length
()
-
1
))
)
if
sorted_images
else
1
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
dataloader
=
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
0
)
...
...
mineru/utils/magic_model_utils.py
0 → 100644
View file @
a5583ff4
"""
包含两个MagicModel类中重复使用的方法和逻辑
"""
from
typing
import
List
,
Dict
,
Any
,
Callable
from
mineru.utils.boxbase
import
bbox_distance
,
is_in
def
reduct_overlap
(
bboxes
:
List
[
Dict
[
str
,
Any
]])
->
List
[
Dict
[
str
,
Any
]]:
"""
去除重叠的bbox,保留不被其他bbox包含的bbox
Args:
bboxes: 包含bbox信息的字典列表
Returns:
去重后的bbox列表
"""
N
=
len
(
bboxes
)
keep
=
[
True
]
*
N
for
i
in
range
(
N
):
for
j
in
range
(
N
):
if
i
==
j
:
continue
if
is_in
(
bboxes
[
i
][
'bbox'
],
bboxes
[
j
][
'bbox'
]):
keep
[
i
]
=
False
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
def
tie_up_category_by_distance_v3
(
get_subjects_func
:
Callable
,
get_objects_func
:
Callable
,
extract_subject_func
:
Callable
=
None
,
extract_object_func
:
Callable
=
None
):
"""
通用的类别关联方法,用于将主体对象与客体对象进行关联
参数:
get_subjects_func: 函数,提取主体对象
get_objects_func: 函数,提取客体对象
extract_subject_func: 函数,自定义提取主体属性(默认使用bbox和其他属性)
extract_object_func: 函数,自定义提取客体属性(默认使用bbox和其他属性)
返回:
关联后的对象列表
"""
subjects
=
get_subjects_func
()
objects
=
get_objects_func
()
# 如果没有提供自定义提取函数,使用默认函数
if
extract_subject_func
is
None
:
extract_subject_func
=
lambda
x
:
x
if
extract_object_func
is
None
:
extract_object_func
=
lambda
x
:
x
ret
=
[]
N
,
M
=
len
(
subjects
),
len
(
objects
)
subjects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
][
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
x
[
"bbox"
][
0
]
**
2
+
x
[
"bbox"
][
1
]
**
2
)
OBJ_IDX_OFFSET
=
10000
SUB_BIT_KIND
,
OBJ_BIT_KIND
=
0
,
1
all_boxes_with_idx
=
[(
i
,
SUB_BIT_KIND
,
sub
[
"bbox"
][
0
],
sub
[
"bbox"
][
1
])
for
i
,
sub
in
enumerate
(
subjects
)]
+
[
(
i
+
OBJ_IDX_OFFSET
,
OBJ_BIT_KIND
,
obj
[
"bbox"
][
0
],
obj
[
"bbox"
][
1
])
for
i
,
obj
in
enumerate
(
objects
)
]
seen_idx
=
set
()
seen_sub_idx
=
set
()
while
N
>
len
(
seen_sub_idx
):
candidates
=
[]
for
idx
,
kind
,
x0
,
y0
in
all_boxes_with_idx
:
if
idx
in
seen_idx
:
continue
candidates
.
append
((
idx
,
kind
,
x0
,
y0
))
if
len
(
candidates
)
==
0
:
break
left_x
=
min
([
v
[
2
]
for
v
in
candidates
])
top_y
=
min
([
v
[
3
]
for
v
in
candidates
])
candidates
.
sort
(
key
=
lambda
x
:
(
x
[
2
]
-
left_x
)
**
2
+
(
x
[
3
]
-
top_y
)
**
2
)
fst_idx
,
fst_kind
,
left_x
,
top_y
=
candidates
[
0
]
fst_bbox
=
subjects
[
fst_idx
][
'bbox'
]
if
fst_kind
==
SUB_BIT_KIND
else
objects
[
fst_idx
-
OBJ_IDX_OFFSET
][
'bbox'
]
candidates
.
sort
(
key
=
lambda
x
:
bbox_distance
(
fst_bbox
,
subjects
[
x
[
0
]][
'bbox'
])
if
x
[
1
]
==
SUB_BIT_KIND
else
bbox_distance
(
fst_bbox
,
objects
[
x
[
0
]
-
OBJ_IDX_OFFSET
][
'bbox'
]))
nxt
=
None
for
i
in
range
(
1
,
len
(
candidates
)):
if
candidates
[
i
][
1
]
^
fst_kind
==
1
:
nxt
=
candidates
[
i
]
break
if
nxt
is
None
:
break
if
fst_kind
==
SUB_BIT_KIND
:
sub_idx
,
obj_idx
=
fst_idx
,
nxt
[
0
]
-
OBJ_IDX_OFFSET
else
:
sub_idx
,
obj_idx
=
nxt
[
0
],
fst_idx
-
OBJ_IDX_OFFSET
pair_dis
=
bbox_distance
(
subjects
[
sub_idx
][
"bbox"
],
objects
[
obj_idx
][
"bbox"
])
nearest_dis
=
float
(
"inf"
)
for
i
in
range
(
N
):
# 取消原先算法中 1对1 匹配的偏置
# if i in seen_idx or i == sub_idx:continue
nearest_dis
=
min
(
nearest_dis
,
bbox_distance
(
subjects
[
i
][
"bbox"
],
objects
[
obj_idx
][
"bbox"
]))
if
pair_dis
>=
3
*
nearest_dis
:
seen_idx
.
add
(
sub_idx
)
continue
seen_idx
.
add
(
sub_idx
)
seen_idx
.
add
(
obj_idx
+
OBJ_IDX_OFFSET
)
seen_sub_idx
.
add
(
sub_idx
)
ret
.
append
(
{
"sub_bbox"
:
extract_subject_func
(
subjects
[
sub_idx
]),
"obj_bboxes"
:
[
extract_object_func
(
objects
[
obj_idx
])],
"sub_idx"
:
sub_idx
,
}
)
for
i
in
range
(
len
(
objects
)):
j
=
i
+
OBJ_IDX_OFFSET
if
j
in
seen_idx
:
continue
seen_idx
.
add
(
j
)
nearest_dis
,
nearest_sub_idx
=
float
(
"inf"
),
-
1
for
k
in
range
(
len
(
subjects
)):
dis
=
bbox_distance
(
objects
[
i
][
"bbox"
],
subjects
[
k
][
"bbox"
])
if
dis
<
nearest_dis
:
nearest_dis
=
dis
nearest_sub_idx
=
k
for
k
in
range
(
len
(
subjects
)):
if
k
!=
nearest_sub_idx
:
continue
if
k
in
seen_sub_idx
:
for
kk
in
range
(
len
(
ret
)):
if
ret
[
kk
][
"sub_idx"
]
==
k
:
ret
[
kk
][
"obj_bboxes"
].
append
(
extract_object_func
(
objects
[
i
]))
break
else
:
ret
.
append
(
{
"sub_bbox"
:
extract_subject_func
(
subjects
[
k
]),
"obj_bboxes"
:
[
extract_object_func
(
objects
[
i
])],
"sub_idx"
:
k
,
}
)
seen_sub_idx
.
add
(
k
)
seen_idx
.
add
(
k
)
for
i
in
range
(
len
(
subjects
)):
if
i
in
seen_sub_idx
:
continue
ret
.
append
(
{
"sub_bbox"
:
extract_subject_func
(
subjects
[
i
]),
"obj_bboxes"
:
[],
"sub_idx"
:
i
,
}
)
return
ret
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment