Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
95e7e3a7
Unverified
Commit
95e7e3a7
authored
Jul 17, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Jul 17, 2024
Browse files
Merge pull request #160 from icecraft/fix/figure_caption_relation
fix: object cluster algorithm
parents
6a9ad924
ddff4b42
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
40 deletions
+44
-40
magic_pdf/model/magic_model.py
magic_pdf/model/magic_model.py
+44
-40
No files found.
magic_pdf/model/magic_model.py
View file @
95e7e3a7
...
@@ -15,7 +15,8 @@ from magic_pdf.libs.boxbase import (
...
@@ -15,7 +15,8 @@ from magic_pdf.libs.boxbase import (
bbox_relative_pos
,
bbox_relative_pos
,
bbox_distance
,
bbox_distance
,
_is_part_overlap
,
_is_part_overlap
,
calculate_overlap_area_in_bbox1_area_ratio
,
calculate_iou
,
calculate_overlap_area_in_bbox1_area_ratio
,
calculate_iou
,
)
)
from
magic_pdf.libs.ModelBlockTypeEnum
import
ModelBlockTypeEnum
from
magic_pdf.libs.ModelBlockTypeEnum
import
ModelBlockTypeEnum
...
@@ -78,9 +79,23 @@ class MagicModel:
...
@@ -78,9 +79,23 @@ class MagicModel:
for
layout_det2
in
layout_dets
:
for
layout_det2
in
layout_dets
:
if
layout_det1
==
layout_det2
:
if
layout_det1
==
layout_det2
:
continue
continue
if
layout_det1
[
"category_id"
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]
and
layout_det2
[
"category_id"
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]:
if
layout_det1
[
"category_id"
]
in
[
if
calculate_iou
(
layout_det1
[
'bbox'
],
layout_det2
[
'bbox'
])
>
0.9
:
0
,
if
layout_det1
[
'score'
]
<
layout_det2
[
'score'
]:
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
]
and
layout_det2
[
"category_id"
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]:
if
(
calculate_iou
(
layout_det1
[
"bbox"
],
layout_det2
[
"bbox"
])
>
0.9
):
if
layout_det1
[
"score"
]
<
layout_det2
[
"score"
]:
layout_det_need_remove
=
layout_det1
layout_det_need_remove
=
layout_det1
else
:
else
:
layout_det_need_remove
=
layout_det2
layout_det_need_remove
=
layout_det2
...
@@ -97,11 +112,11 @@ class MagicModel:
...
@@ -97,11 +112,11 @@ class MagicModel:
def __init__(self, model_list: list, docs: fitz.Document):
    """Keep the raw model output and the document, then clean the detections.

    Args:
        model_list: model output that the rest of the class indexes by page
            number; each entry is read through its "layout_dets" key.
        docs: the ``fitz.Document`` the model output belongs to.
    """
    self.__docs = docs
    self.__model_list = model_list

    # Add bbox information to all model data (scaling, poly -> bbox).
    self.__fix_axis()

    # Remove model data with very low confidence (< 0.05) to improve quality.
    self.__fix_by_remove_low_confidence()

    # Among high-IoU (> 0.9) detections, remove the lower-confidence one.
    self.__fix_by_remove_high_iou_and_low_confidence()
def
__reduct_overlap
(
self
,
bboxes
):
def
__reduct_overlap
(
self
,
bboxes
):
...
@@ -125,16 +140,6 @@ class MagicModel:
...
@@ -125,16 +140,6 @@ class MagicModel:
ret
=
[]
ret
=
[]
MAX_DIS_OF_POINT
=
10
**
9
+
7
MAX_DIS_OF_POINT
=
10
**
9
+
7
def expand_bbox(bbox1, bbox2):
    """Merge two boxes into the smallest box that covers both of them.

    Each box is indexed as ``[x0, y0, x1, y1]``. The result takes the
    minimum of the first two components and the maximum of the last two,
    and is returned as a new 4-element list.
    """
    lo_x, lo_y = min(bbox1[0], bbox2[0]), min(bbox1[1], bbox2[1])
    hi_x, hi_y = max(bbox1[2], bbox2[2]), max(bbox1[3], bbox2[3])
    return [lo_x, lo_y, hi_x, hi_y]
def get_bbox_area(bbox):
    """Return the area of a box indexed as ``[x0, y0, x1, y1]``.

    Both side lengths are taken as absolute values, so the result is
    non-negative even when the corner coordinates are swapped.
    """
    side_x = abs(bbox[2] - bbox[0])
    side_y = abs(bbox[3] - bbox[1])
    return side_x * side_y
# subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
# subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
# 再求出筛选出的 subjects 和 object 的最短距离!
# 再求出筛选出的 subjects 和 object 的最短距离!
def
may_find_other_nearest_bbox
(
subject_idx
,
object_idx
):
def
may_find_other_nearest_bbox
(
subject_idx
,
object_idx
):
...
@@ -177,6 +182,13 @@ class MagicModel:
...
@@ -177,6 +182,13 @@ class MagicModel:
return
ret
return
ret
def expand_bbbox(idxes):
    """Return the bounding box enclosing the boxes selected by ``idxes``.

    Every index is looked up in the enclosing-scope ``all_bboxes`` and its
    "bbox" entry (indexed as ``[x0, y0, x1, y1]``) is read. The result is the
    tuple ``(min x0, min y0, max x1, max y1)`` over the selection.

    NOTE(review): the name has three b's; it is kept as-is because the call
    sites use this exact spelling.
    """
    picked = [all_bboxes[idx]["bbox"] for idx in idxes]
    return (
        min(box[0] for box in picked),
        min(box[1] for box in picked),
        max(box[2] for box in picked),
        max(box[3] for box in picked),
    )
subjects
=
self
.
__reduct_overlap
(
subjects
=
self
.
__reduct_overlap
(
list
(
list
(
map
(
map
(
...
@@ -268,7 +280,9 @@ class MagicModel:
...
@@ -268,7 +280,9 @@ class MagicModel:
or
dis
[
i
][
j
]
==
MAX_DIS_OF_POINT
or
dis
[
i
][
j
]
==
MAX_DIS_OF_POINT
):
):
continue
continue
left
,
right
,
_
,
_
=
bbox_relative_pos
(
all_bboxes
[
i
][
"bbox"
],
all_bboxes
[
j
][
"bbox"
])
# 由 pos_flag_count 相关逻辑保证本段逻辑准确性
left
,
right
,
_
,
_
=
bbox_relative_pos
(
all_bboxes
[
i
][
"bbox"
],
all_bboxes
[
j
][
"bbox"
]
)
# 由 pos_flag_count 相关逻辑保证本段逻辑准确性
if
left
or
right
:
if
left
or
right
:
one_way_dis
=
all_bboxes
[
i
][
"bbox"
][
2
]
-
all_bboxes
[
i
][
"bbox"
][
0
]
one_way_dis
=
all_bboxes
[
i
][
"bbox"
][
2
]
-
all_bboxes
[
i
][
"bbox"
][
0
]
else
:
else
:
...
@@ -322,6 +336,10 @@ class MagicModel:
...
@@ -322,6 +336,10 @@ class MagicModel:
break
break
if
is_nearest
:
if
is_nearest
:
nx0
,
ny0
,
nx1
,
ny1
=
expand_bbbox
(
list
(
seen
)
+
[
k
])
n_dis
=
bbox_distance
(
all_bboxes
[
i
][
"bbox"
],
[
nx0
,
ny0
,
nx1
,
ny1
])
if
float_gt
(
dis
[
i
][
j
],
n_dis
):
continue
tmp
.
append
(
k
)
tmp
.
append
(
k
)
seen
.
add
(
k
)
seen
.
add
(
k
)
...
@@ -331,20 +349,7 @@ class MagicModel:
...
@@ -331,20 +349,7 @@ class MagicModel:
# 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
# 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
# 先扩一下 bbox,
# 先扩一下 bbox,
x0s
=
[
all_bboxes
[
idx
][
"bbox"
][
0
]
for
idx
in
seen
]
+
[
ox0
,
oy0
,
ox1
,
oy1
=
expand_bbbox
(
list
(
seen
)
+
[
i
])
all_bboxes
[
i
][
"bbox"
][
0
]
]
y0s
=
[
all_bboxes
[
idx
][
"bbox"
][
1
]
for
idx
in
seen
]
+
[
all_bboxes
[
i
][
"bbox"
][
1
]
]
x1s
=
[
all_bboxes
[
idx
][
"bbox"
][
2
]
for
idx
in
seen
]
+
[
all_bboxes
[
i
][
"bbox"
][
2
]
]
y1s
=
[
all_bboxes
[
idx
][
"bbox"
][
3
]
for
idx
in
seen
]
+
[
all_bboxes
[
i
][
"bbox"
][
3
]
]
ox0
,
oy0
,
ox1
,
oy1
=
min
(
x0s
),
min
(
y0s
),
max
(
x1s
),
max
(
y1s
)
ix0
,
iy0
,
ix1
,
iy1
=
all_bboxes
[
i
][
"bbox"
]
ix0
,
iy0
,
ix1
,
iy1
=
all_bboxes
[
i
][
"bbox"
]
# 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
# 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
...
@@ -455,8 +460,10 @@ class MagicModel:
...
@@ -455,8 +460,10 @@ class MagicModel:
with_caption_subject
.
add
(
j
)
with_caption_subject
.
add
(
j
)
return
ret
,
total_subject_object_dis
return
ret
,
total_subject_object_dis
def
get_imgs
(
self
,
page_no
:
int
):
# @许瑞
def
get_imgs
(
self
,
page_no
:
int
):
records
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
3
,
4
)
figure_captions
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
3
,
4
)
return
[
return
[
{
{
"bbox"
:
record
[
"all"
],
"bbox"
:
record
[
"all"
],
...
@@ -464,7 +471,7 @@ class MagicModel:
...
@@ -464,7 +471,7 @@ class MagicModel:
"img_caption_bbox"
:
record
.
get
(
"object_body"
,
None
),
"img_caption_bbox"
:
record
.
get
(
"object_body"
,
None
),
"score"
:
record
[
"score"
],
"score"
:
record
[
"score"
],
}
}
for
record
in
record
s
for
record
in
figure_caption
s
]
]
def
get_tables
(
def
get_tables
(
...
@@ -535,6 +542,7 @@ class MagicModel:
...
@@ -535,6 +542,7 @@ class MagicModel:
if
not
any
(
span
==
existing_span
for
existing_span
in
new_spans
):
if
not
any
(
span
==
existing_span
for
existing_span
in
new_spans
):
new_spans
.
append
(
span
)
new_spans
.
append
(
span
)
return
new_spans
return
new_spans
all_spans
=
[]
all_spans
=
[]
model_page_info
=
self
.
__model_list
[
page_no
]
model_page_info
=
self
.
__model_list
[
page_no
]
layout_dets
=
model_page_info
[
"layout_dets"
]
layout_dets
=
model_page_info
[
"layout_dets"
]
...
@@ -548,10 +556,7 @@ class MagicModel:
...
@@ -548,10 +556,7 @@ class MagicModel:
for
layout_det
in
layout_dets
:
for
layout_det
in
layout_dets
:
category_id
=
layout_det
[
"category_id"
]
category_id
=
layout_det
[
"category_id"
]
if
category_id
in
allow_category_id_list
:
if
category_id
in
allow_category_id_list
:
span
=
{
span
=
{
"bbox"
:
layout_det
[
"bbox"
],
"score"
:
layout_det
[
"score"
]}
"bbox"
:
layout_det
[
"bbox"
],
"score"
:
layout_det
[
"score"
]
}
if
category_id
==
3
:
if
category_id
==
3
:
span
[
"type"
]
=
ContentType
.
Image
span
[
"type"
]
=
ContentType
.
Image
elif
category_id
==
5
:
elif
category_id
==
5
:
...
@@ -604,7 +609,6 @@ class MagicModel:
...
@@ -604,7 +609,6 @@ class MagicModel:
return
self
.
__model_list
[
page_no
]
return
self
.
__model_list
[
page_no
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
drw
=
DiskReaderWriter
(
r
"D:/project/20231108code-clean"
)
drw
=
DiskReaderWriter
(
r
"D:/project/20231108code-clean"
)
if
0
:
if
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment