Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
8f1f9abe
Commit
8f1f9abe
authored
May 30, 2025
by
myhloli
Browse files
refactor: enhance bounding box utilities and add configuration reader for S3 integration
parent
7285ea92
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
462 additions
and
509 deletions
+462
-509
mineru/backend/pipeline/config_reader.py
mineru/backend/pipeline/config_reader.py
+117
-0
mineru/backend/pipeline/model_json_to_middle_json.py
mineru/backend/pipeline/model_json_to_middle_json.py
+24
-2
mineru/backend/pipeline/pipeline_analyze.py
mineru/backend/pipeline/pipeline_analyze.py
+11
-14
mineru/backend/vlm/token_to_middle_json.py
mineru/backend/vlm/token_to_middle_json.py
+1
-1
mineru/libs/__init__.py
mineru/libs/__init__.py
+0
-1
mineru/resources/__init__.py
mineru/resources/__init__.py
+0
-1
mineru/resources/fasttext-langdetect/lid.176.ftz
mineru/resources/fasttext-langdetect/lid.176.ftz
+0
-0
mineru/resources/slanet_plus/slanet-plus.onnx
mineru/resources/slanet_plus/slanet-plus.onnx
+0
-0
mineru/utils/boxbase.py
mineru/utils/boxbase.py
+85
-0
mineru/utils/enum_class.py
mineru/utils/enum_class.py
+16
-0
mineru/utils/language.py
mineru/utils/language.py
+48
-0
mineru/utils/model_utils.py
mineru/utils/model_utils.py
+1
-1
mineru/utils/pipeline_magic_model.py
mineru/utils/pipeline_magic_model.py
+159
-489
No files found.
mineru/backend/pipeline/config_reader.py
0 → 100644
View file @
8f1f9abe
# Copyright (c) Opendatalab. All rights reserved.
import
json
import
os
from
loguru
import
logger
# 定义配置文件名常量
CONFIG_FILE_NAME
=
os
.
getenv
(
'MINERU_TOOLS_CONFIG_JSON'
,
'magic-pdf.json'
)
def
read_config
():
if
os
.
path
.
isabs
(
CONFIG_FILE_NAME
):
config_file
=
CONFIG_FILE_NAME
else
:
home_dir
=
os
.
path
.
expanduser
(
'~'
)
config_file
=
os
.
path
.
join
(
home_dir
,
CONFIG_FILE_NAME
)
if
not
os
.
path
.
exists
(
config_file
):
raise
FileNotFoundError
(
f
'
{
config_file
}
not found'
)
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
config
=
json
.
load
(
f
)
return
config
def
get_s3_config
(
bucket_name
:
str
):
"""~/magic-pdf.json 读出来."""
config
=
read_config
()
bucket_info
=
config
.
get
(
'bucket_info'
)
if
bucket_name
not
in
bucket_info
:
access_key
,
secret_key
,
storage_endpoint
=
bucket_info
[
'[default]'
]
else
:
access_key
,
secret_key
,
storage_endpoint
=
bucket_info
[
bucket_name
]
if
access_key
is
None
or
secret_key
is
None
or
storage_endpoint
is
None
:
raise
Exception
(
f
'ak, sk or endpoint not found in
{
CONFIG_FILE_NAME
}
'
)
# logger.info(f"get_s3_config: ak={access_key}, sk={secret_key}, endpoint={storage_endpoint}")
return
access_key
,
secret_key
,
storage_endpoint
def
get_s3_config_dict
(
path
:
str
):
access_key
,
secret_key
,
storage_endpoint
=
get_s3_config
(
get_bucket_name
(
path
))
return
{
'ak'
:
access_key
,
'sk'
:
secret_key
,
'endpoint'
:
storage_endpoint
}
def
get_bucket_name
(
path
):
bucket
,
key
=
parse_bucket_key
(
path
)
return
bucket
def
parse_bucket_key
(
s3_full_path
:
str
):
"""
输入 s3://bucket/path/to/my/file.txt
输出 bucket, path/to/my/file.txt
"""
s3_full_path
=
s3_full_path
.
strip
()
if
s3_full_path
.
startswith
(
"s3://"
):
s3_full_path
=
s3_full_path
[
5
:]
if
s3_full_path
.
startswith
(
"/"
):
s3_full_path
=
s3_full_path
[
1
:]
bucket
,
key
=
s3_full_path
.
split
(
"/"
,
1
)
return
bucket
,
key
def
get_local_models_dir
():
config
=
read_config
()
models_dir
=
config
.
get
(
'models-dir'
)
if
models_dir
is
None
:
logger
.
warning
(
f
"'models-dir' not found in
{
CONFIG_FILE_NAME
}
, use '/tmp/models' as default"
)
return
'/tmp/models'
else
:
return
models_dir
def
get_local_layoutreader_model_dir
():
config
=
read_config
()
layoutreader_model_dir
=
config
.
get
(
'layoutreader-model-dir'
)
if
layoutreader_model_dir
is
None
or
not
os
.
path
.
exists
(
layoutreader_model_dir
):
home_dir
=
os
.
path
.
expanduser
(
'~'
)
layoutreader_at_modelscope_dir_path
=
os
.
path
.
join
(
home_dir
,
'.cache/modelscope/hub/ppaanngggg/layoutreader'
)
logger
.
warning
(
f
"'layoutreader-model-dir' not exists, use
{
layoutreader_at_modelscope_dir_path
}
as default"
)
return
layoutreader_at_modelscope_dir_path
else
:
return
layoutreader_model_dir
def
get_device
():
config
=
read_config
()
device
=
config
.
get
(
'device-mode'
)
if
device
is
None
:
logger
.
warning
(
f
"'device-mode' not found in
{
CONFIG_FILE_NAME
}
, use 'cpu' as default"
)
return
'cpu'
else
:
return
device
def
get_table_recog_config
():
config
=
read_config
()
table_config
=
config
.
get
(
'table-config'
)
if
table_config
is
None
:
logger
.
warning
(
f
"'table-config' not found in
{
CONFIG_FILE_NAME
}
, use 'False' as default"
)
return
json
.
loads
(
f
'{{"enable": true}}'
)
else
:
return
table_config
def
get_formula_config
():
config
=
read_config
()
formula_config
=
config
.
get
(
'formula-config'
)
if
formula_config
is
None
:
logger
.
warning
(
f
"'formula-config' not found in
{
CONFIG_FILE_NAME
}
, use 'True' as default"
)
return
json
.
loads
(
f
'{{"enable": true}}'
)
else
:
return
formula_config
\ No newline at end of file
mineru/backend/pipeline/model_json_to_middle_json.py
View file @
8f1f9abe
# Copyright (c) Opendatalab. All rights reserved.
# Copyright (c) Opendatalab. All rights reserved.
def
result_to_middle_json
(
model_json
,
images_list
,
pdf_doc
,
image_writer
):
from
mineru.utils.pipeline_magic_model
import
MagicModel
pass
from
mineru.version
import
__version__
\ No newline at end of file
from
mineru.utils.hash_utils
import
str_md5
def
page_model_info_to_page_info
(
page_model_info
,
image_dict
,
page
,
image_writer
,
page_index
,
lang
=
None
,
ocr
=
False
):
scale
=
image_dict
[
"scale"
]
page_pil_img
=
image_dict
[
"img_pil"
]
page_img_md5
=
str_md5
(
image_dict
[
"img_base64"
])
width
,
height
=
map
(
int
,
page
.
get_size
())
magic_model
=
MagicModel
(
page_model_info
,
scale
)
def
result_to_middle_json
(
model_list
,
images_list
,
pdf_doc
,
image_writer
,
lang
=
None
,
ocr
=
False
):
middle_json
=
{
"pdf_info"
:
[],
"_backend"
:
"vlm"
,
"_version_name"
:
__version__
}
for
page_index
,
page_model_info
in
enumerate
(
model_list
):
page
=
pdf_doc
[
page_index
]
image_dict
=
images_list
[
page_index
]
page_info
=
page_model_info_to_page_info
(
page_model_info
,
image_dict
,
page
,
image_writer
,
page_index
,
lang
=
lang
,
ocr
=
ocr
)
middle_json
[
"pdf_info"
].
append
(
page_info
)
return
middle_json
\ No newline at end of file
mineru/backend/pipeline/pipeline_analyze.py
View file @
8f1f9abe
...
@@ -2,9 +2,9 @@ import os
...
@@ -2,9 +2,9 @@ import os
import
time
import
time
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
pypdfium2
import
PdfDocument
from
mineru.backend.pipeline.model_init
import
MineruPipelineModel
from
.model_init
import
MineruPipelineModel
from
.config_reader
import
get_local_models_dir
,
get_device
,
get_formula_config
,
get_table_recog_config
from
.model_json_to_middle_json
import
result_to_middle_json
from
.model_json_to_middle_json
import
result_to_middle_json
from
...data.data_reader_writer
import
DataWriter
from
...data.data_reader_writer
import
DataWriter
from
...utils.pdf_classify
import
classify
from
...utils.pdf_classify
import
classify
...
@@ -13,11 +13,6 @@ from ...utils.pdf_image_tools import load_images_from_pdf
...
@@ -13,11 +13,6 @@ from ...utils.pdf_image_tools import load_images_from_pdf
from
loguru
import
logger
from
loguru
import
logger
from
...utils.model_utils
import
get_vram
,
clean_memory
from
...utils.model_utils
import
get_vram
,
clean_memory
from
magic_pdf.libs.config_reader
import
(
get_device
,
get_formula_config
,
get_layout_config
,
get_local_models_dir
,
get_table_recog_config
)
os
.
environ
[
'PYTORCH_ENABLE_MPS_FALLBACK'
]
=
'1'
# 让mps可以fallback
os
.
environ
[
'PYTORCH_ENABLE_MPS_FALLBACK'
]
=
'1'
# 让mps可以fallback
...
@@ -109,6 +104,7 @@ def doc_analyze(
...
@@ -109,6 +104,7 @@ def doc_analyze(
all_image_lists
=
[]
all_image_lists
=
[]
all_pdf_docs
=
[]
all_pdf_docs
=
[]
ocr_enabled_list
=
[]
for
pdf_idx
,
pdf_bytes
in
enumerate
(
pdf_bytes_list
):
for
pdf_idx
,
pdf_bytes
in
enumerate
(
pdf_bytes_list
):
# 确定OCR设置
# 确定OCR设置
_ocr
=
False
_ocr
=
False
...
@@ -118,6 +114,7 @@ def doc_analyze(
...
@@ -118,6 +114,7 @@ def doc_analyze(
elif
parse_method
==
'ocr'
:
elif
parse_method
==
'ocr'
:
_ocr
=
True
_ocr
=
True
ocr_enabled_list
[
pdf_idx
]
=
_ocr
_lang
=
lang_list
[
pdf_idx
]
_lang
=
lang_list
[
pdf_idx
]
# 收集每个数据集中的页面
# 收集每个数据集中的页面
...
@@ -152,23 +149,23 @@ def doc_analyze(
...
@@ -152,23 +149,23 @@ def doc_analyze(
results
.
extend
(
batch_results
)
results
.
extend
(
batch_results
)
# 构建返回结果
# 构建返回结果
infer_results
=
[]
# 多数据集模式:按数据集分组结果
infer_results
=
[[]
for
_
in
datasets
]
for
i
,
page_info
in
enumerate
(
all_pages_info
):
for
i
,
page_info
in
enumerate
(
all_pages_info
):
pdf_idx
,
page_idx
,
pil_img
,
_
,
_
=
page_info
pdf_idx
,
page_idx
,
pil_img
,
_
,
_
=
page_info
result
=
results
[
i
]
result
=
results
[
i
]
page_info_dict
=
{
'page_no'
:
page_idx
,
'width'
:
pil_img
.
get_
width
()
,
'height'
:
pil_img
.
get_
height
()
}
page_info_dict
=
{
'page_no'
:
page_idx
,
'width'
:
pil_img
.
width
,
'height'
:
pil_img
.
height
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info_dict
}
page_dict
=
{
'layout_dets'
:
result
,
'page_info'
:
page_info_dict
}
infer_results
[
pdf_idx
]
.
append
(
page_dict
)
infer_results
[
pdf_idx
]
[
page_idx
]
=
page_dict
middle_json_list
=
[]
middle_json_list
=
[]
for
pdf_idx
,
model_
json
in
enumerate
(
infer_results
):
for
pdf_idx
,
model_
list
in
enumerate
(
infer_results
):
images_list
=
all_image_lists
[
pdf_idx
]
images_list
=
all_image_lists
[
pdf_idx
]
pdf_doc
=
all_pdf_docs
[
pdf_idx
]
pdf_doc
=
all_pdf_docs
[
pdf_idx
]
middle_json
=
result_to_middle_json
(
model_json
,
images_list
,
pdf_doc
,
image_writer
)
_lang
=
lang_list
[
pdf_idx
]
_ocr
=
ocr_enabled_list
[
pdf_idx
]
middle_json
=
result_to_middle_json
(
model_list
,
images_list
,
pdf_doc
,
image_writer
,
_lang
,
_ocr
)
middle_json_list
.
append
(
middle_json
)
middle_json_list
.
append
(
middle_json
)
return
middle_json_list
,
infer_results
return
middle_json_list
,
infer_results
...
...
mineru/backend/vlm/token_to_middle_json.py
View file @
8f1f9abe
...
@@ -118,7 +118,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
...
@@ -118,7 +118,7 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
def
result_to_middle_json
(
token_list
,
images_list
,
pdf_doc
,
image_writer
):
def
result_to_middle_json
(
token_list
,
images_list
,
pdf_doc
,
image_writer
):
middle_json
=
{
"pdf_info"
:
[],
"_version_name"
:
__version__
}
middle_json
=
{
"pdf_info"
:
[],
"_backend"
:
"vlm"
,
"_version_name"
:
__version__
}
for
index
,
token
in
enumerate
(
token_list
):
for
index
,
token
in
enumerate
(
token_list
):
page
=
pdf_doc
[
index
]
page
=
pdf_doc
[
index
]
image_dict
=
images_list
[
index
]
image_dict
=
images_list
[
index
]
...
...
mineru/libs/__init__.py
deleted
100644 → 0
View file @
7285ea92
# Copyright (c) Opendatalab. All rights reserved.
mineru/resources/__init__.py
deleted
100644 → 0
View file @
7285ea92
# Copyright (c) Opendatalab. All rights reserved.
mineru/resources/fasttext-langdetect/lid.176.ftz
0 → 100644
View file @
8f1f9abe
File added
mineru/resources/slanet_plus/slanet-plus.onnx
0 → 100644
View file @
8f1f9abe
File added
mineru/utils/boxbase.py
View file @
8f1f9abe
...
@@ -72,3 +72,88 @@ def bbox_distance(bbox1, bbox2):
...
@@ -72,3 +72,88 @@ def bbox_distance(bbox1, bbox2):
elif
top
:
elif
top
:
return
y2
-
y1b
return
y2
-
y1b
return
0.0
return
0.0
def
get_minbox_if_overlap_by_ratio
(
bbox1
,
bbox2
,
ratio
):
"""通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
如果比例大于ratio,则返回小的那个bbox, 否则返回None."""
x1_min
,
y1_min
,
x1_max
,
y1_max
=
bbox1
x2_min
,
y2_min
,
x2_max
,
y2_max
=
bbox2
area1
=
(
x1_max
-
x1_min
)
*
(
y1_max
-
y1_min
)
area2
=
(
x2_max
-
x2_min
)
*
(
y2_max
-
y2_min
)
overlap_ratio
=
calculate_overlap_area_2_minbox_area_ratio
(
bbox1
,
bbox2
)
if
overlap_ratio
>
ratio
:
if
area1
<=
area2
:
return
bbox1
else
:
return
bbox2
else
:
return
None
def
calculate_overlap_area_2_minbox_area_ratio
(
bbox1
,
bbox2
):
"""计算box1和box2的重叠面积占最小面积的box的比例."""
# Determine the coordinates of the intersection rectangle
x_left
=
max
(
bbox1
[
0
],
bbox2
[
0
])
y_top
=
max
(
bbox1
[
1
],
bbox2
[
1
])
x_right
=
min
(
bbox1
[
2
],
bbox2
[
2
])
y_bottom
=
min
(
bbox1
[
3
],
bbox2
[
3
])
if
x_right
<
x_left
or
y_bottom
<
y_top
:
return
0.0
# The area of overlap area
intersection_area
=
(
x_right
-
x_left
)
*
(
y_bottom
-
y_top
)
min_box_area
=
min
([(
bbox1
[
2
]
-
bbox1
[
0
])
*
(
bbox1
[
3
]
-
bbox1
[
1
]),
(
bbox2
[
3
]
-
bbox2
[
1
])
*
(
bbox2
[
2
]
-
bbox2
[
0
])])
if
min_box_area
==
0
:
return
0
else
:
return
intersection_area
/
min_box_area
def
calculate_iou
(
bbox1
,
bbox2
):
"""计算两个边界框的交并比(IOU)。
Args:
bbox1 (list[float]): 第一个边界框的坐标,格式为 [x1, y1, x2, y2],其中 (x1, y1) 为左上角坐标,(x2, y2) 为右下角坐标。
bbox2 (list[float]): 第二个边界框的坐标,格式与 `bbox1` 相同。
Returns:
float: 两个边界框的交并比(IOU),取值范围为 [0, 1]。
"""
# Determine the coordinates of the intersection rectangle
x_left
=
max
(
bbox1
[
0
],
bbox2
[
0
])
y_top
=
max
(
bbox1
[
1
],
bbox2
[
1
])
x_right
=
min
(
bbox1
[
2
],
bbox2
[
2
])
y_bottom
=
min
(
bbox1
[
3
],
bbox2
[
3
])
if
x_right
<
x_left
or
y_bottom
<
y_top
:
return
0.0
# The area of overlap area
intersection_area
=
(
x_right
-
x_left
)
*
(
y_bottom
-
y_top
)
# The area of both rectangles
bbox1_area
=
(
bbox1
[
2
]
-
bbox1
[
0
])
*
(
bbox1
[
3
]
-
bbox1
[
1
])
bbox2_area
=
(
bbox2
[
2
]
-
bbox2
[
0
])
*
(
bbox2
[
3
]
-
bbox2
[
1
])
if
any
([
bbox1_area
==
0
,
bbox2_area
==
0
]):
return
0
# Compute the intersection over union by taking the intersection area
# and dividing it by the sum of both areas minus the intersection area
iou
=
intersection_area
/
float
(
bbox1_area
+
bbox2_area
-
intersection_area
)
return
iou
def
_is_in
(
box1
,
box2
)
->
bool
:
"""box1是否完全在box2里面."""
x0_1
,
y0_1
,
x1_1
,
y1_1
=
box1
x0_2
,
y0_2
,
x1_2
,
y1_2
=
box2
return
(
x0_1
>=
x0_2
and
# box1的左边界不在box2的左边外
y0_1
>=
y0_2
and
# box1的上边界不在box2的上边外
x1_1
<=
x1_2
and
# box1的右边界不在box2的右边外
y1_1
<=
y1_2
)
# box1的下边界不在box2的下边外
\ No newline at end of file
mineru/utils/enum_class.py
View file @
8f1f9abe
...
@@ -23,6 +23,22 @@ class ContentType:
...
@@ -23,6 +23,22 @@ class ContentType:
INLINE_EQUATION
=
'inline_equation'
INLINE_EQUATION
=
'inline_equation'
class
CategoryId
:
Title
=
0
Text
=
1
Abandon
=
2
ImageBody
=
3
ImageCaption
=
4
TableBody
=
5
TableCaption
=
6
TableFootnote
=
7
InterlineEquation_Layout
=
8
InlineEquation
=
13
InterlineEquation_YOLO
=
14
OcrText
=
15
ImageFootnote
=
101
class
MakeMode
:
class
MakeMode
:
MM_MD
=
'mm_markdown'
MM_MD
=
'mm_markdown'
NLP_MD
=
'nlp_markdown'
NLP_MD
=
'nlp_markdown'
...
...
mineru/utils/language.py
0 → 100644
View file @
8f1f9abe
import
os
import
unicodedata
if
not
os
.
getenv
(
"FTLANG_CACHE"
):
current_file_path
=
os
.
path
.
abspath
(
__file__
)
current_dir
=
os
.
path
.
dirname
(
current_file_path
)
root_dir
=
os
.
path
.
dirname
(
current_dir
)
ftlang_cache_dir
=
os
.
path
.
join
(
root_dir
,
'resources'
,
'fasttext-langdetect'
)
os
.
environ
[
"FTLANG_CACHE"
]
=
str
(
ftlang_cache_dir
)
# print(os.getenv("FTLANG_CACHE"))
from
fast_langdetect
import
detect_language
def
remove_invalid_surrogates
(
text
):
# 移除无效的 UTF-16 代理对
return
''
.
join
(
c
for
c
in
text
if
not
(
0xD800
<=
ord
(
c
)
<=
0xDFFF
))
def
detect_lang
(
text
:
str
)
->
str
:
if
len
(
text
)
==
0
:
return
""
text
=
text
.
replace
(
"
\n
"
,
""
)
text
=
remove_invalid_surrogates
(
text
)
# print(text)
try
:
lang_upper
=
detect_language
(
text
)
except
:
html_no_ctrl_chars
=
''
.
join
([
l
for
l
in
text
if
unicodedata
.
category
(
l
)[
0
]
not
in
[
'C'
,
]])
lang_upper
=
detect_language
(
html_no_ctrl_chars
)
try
:
lang
=
lang_upper
.
lower
()
except
:
lang
=
""
return
lang
if
__name__
==
'__main__'
:
print
(
os
.
getenv
(
"FTLANG_CACHE"
))
print
(
detect_lang
(
"This is a test."
))
print
(
detect_lang
(
"<html>This is a test</html>"
))
print
(
detect_lang
(
"这个是中文测试。"
))
print
(
detect_lang
(
"<html>这个是中文测试。</html>"
))
print
(
detect_lang
(
"〖
\ud835\udc46\ud835
〗这是个包含utf-16的中文测试"
))
\ No newline at end of file
mineru/utils/model_utils.py
View file @
8f1f9abe
...
@@ -4,7 +4,7 @@ import gc
...
@@ -4,7 +4,7 @@ import gc
from
loguru
import
logger
from
loguru
import
logger
import
numpy
as
np
import
numpy
as
np
from
m
agic_pdf.lib
s.boxbase
import
get_minbox_if_overlap_by_ratio
from
m
ineru.util
s.boxbase
import
get_minbox_if_overlap_by_ratio
def
crop_img
(
input_res
,
input_np_img
,
crop_paste_x
=
0
,
crop_paste_y
=
0
):
def
crop_img
(
input_res
,
input_np_img
,
crop_paste_x
=
0
,
crop_paste_y
=
0
):
...
...
mineru/utils/pipeline_magic_model.py
View file @
8f1f9abe
import
enum
from
mineru.utils.boxbase
import
bbox_relative_pos
,
calculate_iou
,
bbox_distance
,
_is_in
from
mineru.utils.enum_class
import
CategoryId
,
ContentType
from
magic_pdf.config.model_block_type
import
ModelBlockTypeEnum
from
magic_pdf.config.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.boxbase
import
(
_is_in
,
bbox_distance
,
bbox_relative_pos
,
calculate_iou
)
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO
=
0.6
MERGE_BOX_OVERLAP_AREA_RATIO
=
1.1
class
PosRelationEnum
(
enum
.
Enum
):
LEFT
=
'left'
RIGHT
=
'right'
UP
=
'up'
BOTTOM
=
'bottom'
ALL
=
'all'
class
MagicModel
:
class
MagicModel
:
"""每个函数没有得到元素的时候返回空list."""
"""每个函数没有得到元素的时候返回空list."""
def
__init__
(
self
,
page_model_info
:
dict
,
scale
:
float
):
self
.
__page_model_info
=
page_model_info
self
.
__scale
=
scale
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
self
.
__fix_axis
()
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self
.
__fix_by_remove_low_confidence
()
"""删除高iou(>0.9)数据中置信度较低的那个"""
self
.
__fix_by_remove_high_iou_and_low_confidence
()
self
.
__fix_footnote
()
def
__fix_axis
(
self
):
def
__fix_axis
(
self
):
for
model_page_info
in
self
.
__model_list
:
need_remove_list
=
[]
need_remove_list
=
[]
layout_dets
=
self
.
__page_model_info
[
'layout_dets'
]
page_no
=
model_page_info
[
'page_info'
][
'page_no'
]
for
layout_det
in
layout_dets
:
horizontal_scale_ratio
,
vertical_scale_ratio
=
get_scale_ratio
(
x0
,
y0
,
_
,
_
,
x1
,
y1
,
_
,
_
=
layout_det
[
'poly'
]
model_page_info
,
self
.
__docs
.
get_page
(
page_no
)
bbox
=
[
)
int
(
x0
/
self
.
__scale
),
layout_dets
=
model_page_info
[
'layout_dets'
]
int
(
y0
/
self
.
__scale
),
for
layout_det
in
layout_dets
:
int
(
x1
/
self
.
__scale
),
int
(
y1
/
self
.
__scale
),
if
layout_det
.
get
(
'bbox'
)
is
not
None
:
]
# 兼容直接输出bbox的模型数据,如paddle
layout_det
[
'bbox'
]
=
bbox
x0
,
y0
,
x1
,
y1
=
layout_det
[
'bbox'
]
# 删除高度或者宽度小于等于0的spans
else
:
if
bbox
[
2
]
-
bbox
[
0
]
<=
0
or
bbox
[
3
]
-
bbox
[
1
]
<=
0
:
# 兼容直接输出poly的模型数据,如xxx
need_remove_list
.
append
(
layout_det
)
x0
,
y0
,
_
,
_
,
x1
,
y1
,
_
,
_
=
layout_det
[
'poly'
]
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
bbox
=
[
int
(
x0
/
horizontal_scale_ratio
),
int
(
y0
/
vertical_scale_ratio
),
int
(
x1
/
horizontal_scale_ratio
),
int
(
y1
/
vertical_scale_ratio
),
]
layout_det
[
'bbox'
]
=
bbox
# 删除高度或者宽度小于等于0的spans
if
bbox
[
2
]
-
bbox
[
0
]
<=
0
or
bbox
[
3
]
-
bbox
[
1
]
<=
0
:
need_remove_list
.
append
(
layout_det
)
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_low_confidence
(
self
):
def
__fix_by_remove_low_confidence
(
self
):
for
model_page_info
in
self
.
__model_list
:
need_remove_list
=
[]
need_remove_list
=
[]
layout_dets
=
self
.
__page_model_info
[
'layout_dets'
]
layout_dets
=
model_page_info
[
'layout_dets'
]
for
layout_det
in
layout_dets
:
for
layout_det
in
layout_dets
:
if
layout_det
[
'score'
]
<=
0.05
:
if
layout_det
[
'score'
]
<=
0.05
:
need_remove_list
.
append
(
layout_det
)
need_remove_list
.
append
(
layout_det
)
else
:
else
:
continue
continue
for
need_remove
in
need_remove_list
:
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
layout_dets
.
remove
(
need_remove
)
def
__fix_by_remove_high_iou_and_low_confidence
(
self
):
def
__fix_by_remove_high_iou_and_low_confidence
(
self
):
for
model_page_info
in
self
.
__model_list
:
need_remove_list
=
[]
need_remove_list
=
[]
layout_dets
=
self
.
__page_model_info
[
'layout_dets'
]
layout_dets
=
model_page_info
[
'layout_dets'
]
for
layout_det1
in
layout_dets
:
for
layout_det1
in
layout_dets
:
for
layout_det2
in
layout_dets
:
for
layout_det2
in
layout_dets
:
if
layout_det1
==
layout_det2
:
if
layout_det1
==
layout_det2
:
continue
continue
if
layout_det1
[
'category_id'
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]
and
layout_det2
[
'category_id'
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]:
if
layout_det1
[
'category_id'
]
in
[
if
(
0
,
calculate_iou
(
layout_det1
[
'bbox'
],
layout_det2
[
'bbox'
])
1
,
>
0.9
2
,
):
3
,
if
layout_det1
[
'score'
]
<
layout_det2
[
'score'
]:
4
,
layout_det_need_remove
=
layout_det1
5
,
6
,
7
,
8
,
9
,
]
and
layout_det2
[
'category_id'
]
in
[
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
]:
if
(
calculate_iou
(
layout_det1
[
'bbox'
],
layout_det2
[
'bbox'
])
>
0.9
):
if
layout_det1
[
'score'
]
<
layout_det2
[
'score'
]:
layout_det_need_remove
=
layout_det1
else
:
layout_det_need_remove
=
layout_det2
if
layout_det_need_remove
not
in
need_remove_list
:
need_remove_list
.
append
(
layout_det_need_remove
)
else
:
else
:
continue
layout_det_need_remove
=
layout_det2
if
layout_det_need_remove
not
in
need_remove_list
:
need_remove_list
.
append
(
layout_det_need_remove
)
else
:
else
:
continue
continue
for
need_remove
in
need_remove_list
:
else
:
layout_dets
.
remove
(
need_remove
)
continue
for
need_remove
in
need_remove_list
:
layout_dets
.
remove
(
need_remove
)
def
__init__
(
self
,
model_list
:
list
,
docs
:
Dataset
):
def
__fix_footnote
(
self
):
self
.
__model_list
=
model_list
# 3: figure, 5: table, 7: footnote
self
.
__docs
=
docs
footnotes
=
[]
"""为所有模型数据添加bbox信息(缩放,poly->bbox)"""
figures
=
[]
self
.
__fix_axis
()
tables
=
[]
"""删除置信度特别低的模型数据(<0.05),提高质量"""
self
.
__fix_by_remove_low_confidence
()
for
obj
in
self
.
__page_model_info
[
'layout_dets'
]:
"""删除高iou(>0.9)数据中置信度较低的那个"""
if
obj
[
'category_id'
]
==
7
:
self
.
__fix_by_remove_high_iou_and_low_confidence
()
footnotes
.
append
(
obj
)
self
.
__fix_footnote
()
elif
obj
[
'category_id'
]
==
3
:
figures
.
append
(
obj
)
elif
obj
[
'category_id'
]
==
5
:
tables
.
append
(
obj
)
if
len
(
footnotes
)
*
len
(
figures
)
==
0
:
continue
dis_figure_footnote
=
{}
dis_table_footnote
=
{}
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
figures
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
figures
[
j
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
dis_figure_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
tables
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
tables
[
j
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
dis_table_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
if
i
not
in
dis_figure_footnote
:
continue
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
def
_bbox_distance
(
self
,
bbox1
,
bbox2
):
def
_bbox_distance
(
self
,
bbox1
,
bbox2
):
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
...
@@ -132,68 +149,6 @@ class MagicModel:
...
@@ -132,68 +149,6 @@ class MagicModel:
return
bbox_distance
(
bbox1
,
bbox2
)
return
bbox_distance
(
bbox1
,
bbox2
)
def
__fix_footnote
(
self
):
# 3: figure, 5: table, 7: footnote
for
model_page_info
in
self
.
__model_list
:
footnotes
=
[]
figures
=
[]
tables
=
[]
for
obj
in
model_page_info
[
'layout_dets'
]:
if
obj
[
'category_id'
]
==
7
:
footnotes
.
append
(
obj
)
elif
obj
[
'category_id'
]
==
3
:
figures
.
append
(
obj
)
elif
obj
[
'category_id'
]
==
5
:
tables
.
append
(
obj
)
if
len
(
footnotes
)
*
len
(
figures
)
==
0
:
continue
dis_figure_footnote
=
{}
dis_table_footnote
=
{}
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
figures
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
figures
[
j
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
dis_figure_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
figures
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_figure_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
for
j
in
range
(
len
(
tables
)):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
footnotes
[
i
][
'bbox'
],
tables
[
j
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
dis_table_footnote
[
i
]
=
min
(
self
.
_bbox_distance
(
tables
[
j
][
'bbox'
],
footnotes
[
i
][
'bbox'
]),
dis_table_footnote
.
get
(
i
,
float
(
'inf'
)),
)
for
i
in
range
(
len
(
footnotes
)):
if
i
not
in
dis_figure_footnote
:
continue
if
dis_table_footnote
.
get
(
i
,
float
(
'inf'
))
>
dis_figure_footnote
[
i
]:
footnotes
[
i
][
'category_id'
]
=
CategoryId
.
ImageFootnote
def
__reduct_overlap
(
self
,
bboxes
):
def
__reduct_overlap
(
self
,
bboxes
):
N
=
len
(
bboxes
)
N
=
len
(
bboxes
)
keep
=
[
True
]
*
N
keep
=
[
True
]
*
N
...
@@ -205,258 +160,10 @@ class MagicModel:
...
@@ -205,258 +160,10 @@ class MagicModel:
keep
[
i
]
=
False
keep
[
i
]
=
False
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
def
__tie_up_category_by_distance_v2
(
self
,
page_no
:
int
,
subject_category_id
:
int
,
object_category_id
:
int
,
priority_pos
:
PosRelationEnum
,
):
"""_summary_
Args:
page_no (int): _description_
subject_category_id (int): _description_
object_category_id (int): _description_
priority_pos (PosRelationEnum): _description_
Returns:
_type_: _description_
"""
AXIS_MULPLICITY
=
0.5
subjects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
)
objects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
)
M
=
len
(
objects
)
subjects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
objects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
sub_obj_map_h
=
{
i
:
[]
for
i
in
range
(
len
(
subjects
))}
dis_by_directions
=
{
'top'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
'bottom'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
'left'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
'right'
:
[[
-
1
,
float
(
'inf'
)]]
*
M
,
}
for
i
,
obj
in
enumerate
(
objects
):
l_x_axis
,
l_y_axis
=
(
obj
[
'bbox'
][
2
]
-
obj
[
'bbox'
][
0
],
obj
[
'bbox'
][
3
]
-
obj
[
'bbox'
][
1
],
)
axis_unit
=
min
(
l_x_axis
,
l_y_axis
)
for
j
,
sub
in
enumerate
(
subjects
):
bbox1
,
bbox2
,
_
=
_remove_overlap_between_bbox
(
objects
[
i
][
'bbox'
],
subjects
[
j
][
'bbox'
]
)
left
,
right
,
bottom
,
top
=
bbox_relative_pos
(
bbox1
,
bbox2
)
flags
=
[
left
,
right
,
bottom
,
top
]
if
sum
([
1
if
v
else
0
for
v
in
flags
])
>
1
:
continue
if
left
:
if
dis_by_directions
[
'left'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'left'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
right
:
if
dis_by_directions
[
'right'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'right'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
bottom
:
if
dis_by_directions
[
'bottom'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'bottom'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
top
:
if
dis_by_directions
[
'top'
][
i
][
1
]
>
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]
):
dis_by_directions
[
'top'
][
i
]
=
[
j
,
bbox_distance
(
obj
[
'bbox'
],
sub
[
'bbox'
]),
]
if
(
dis_by_directions
[
'top'
][
i
][
1
]
!=
float
(
'inf'
)
and
dis_by_directions
[
'bottom'
][
i
][
1
]
!=
float
(
'inf'
)
and
priority_pos
in
(
PosRelationEnum
.
BOTTOM
,
PosRelationEnum
.
UP
)
):
RATIO
=
3
if
(
abs
(
dis_by_directions
[
'top'
][
i
][
1
]
-
dis_by_directions
[
'bottom'
][
i
][
1
]
)
<
RATIO
*
axis_unit
):
if
priority_pos
==
PosRelationEnum
.
BOTTOM
:
sub_obj_map_h
[
dis_by_directions
[
'bottom'
][
i
][
0
]].
append
(
i
)
else
:
sub_obj_map_h
[
dis_by_directions
[
'top'
][
i
][
0
]].
append
(
i
)
continue
if
dis_by_directions
[
'left'
][
i
][
1
]
!=
float
(
'inf'
)
or
dis_by_directions
[
'right'
][
i
][
1
]
!=
float
(
'inf'
):
if
dis_by_directions
[
'left'
][
i
][
1
]
!=
float
(
'inf'
)
and
dis_by_directions
[
'right'
][
i
][
1
]
!=
float
(
'inf'
):
if
AXIS_MULPLICITY
*
axis_unit
>=
abs
(
dis_by_directions
[
'left'
][
i
][
1
]
-
dis_by_directions
[
'right'
][
i
][
1
]
):
left_sub_bbox
=
subjects
[
dis_by_directions
[
'left'
][
i
][
0
]][
'bbox'
]
right_sub_bbox
=
subjects
[
dis_by_directions
[
'right'
][
i
][
0
]][
'bbox'
]
left_sub_bbox_y_axis
=
left_sub_bbox
[
3
]
-
left_sub_bbox
[
1
]
right_sub_bbox_y_axis
=
right_sub_bbox
[
3
]
-
right_sub_bbox
[
1
]
if
(
abs
(
left_sub_bbox_y_axis
-
l_y_axis
)
+
dis_by_directions
[
'left'
][
i
][
0
]
>
abs
(
right_sub_bbox_y_axis
-
l_y_axis
)
+
dis_by_directions
[
'right'
][
i
][
0
]
):
left_or_right
=
dis_by_directions
[
'right'
][
i
]
else
:
left_or_right
=
dis_by_directions
[
'left'
][
i
]
else
:
left_or_right
=
dis_by_directions
[
'left'
][
i
]
if
left_or_right
[
1
]
>
dis_by_directions
[
'right'
][
i
][
1
]:
left_or_right
=
dis_by_directions
[
'right'
][
i
]
else
:
left_or_right
=
dis_by_directions
[
'left'
][
i
]
if
left_or_right
[
1
]
==
float
(
'inf'
):
left_or_right
=
dis_by_directions
[
'right'
][
i
]
else
:
left_or_right
=
[
-
1
,
float
(
'inf'
)]
if
dis_by_directions
[
'top'
][
i
][
1
]
!=
float
(
'inf'
)
or
dis_by_directions
[
'bottom'
][
i
][
1
]
!=
float
(
'inf'
):
if
dis_by_directions
[
'top'
][
i
][
1
]
!=
float
(
'inf'
)
and
dis_by_directions
[
'bottom'
][
i
][
1
]
!=
float
(
'inf'
):
if
AXIS_MULPLICITY
*
axis_unit
>=
abs
(
dis_by_directions
[
'top'
][
i
][
1
]
-
dis_by_directions
[
'bottom'
][
i
][
1
]
):
top_bottom
=
subjects
[
dis_by_directions
[
'bottom'
][
i
][
0
]][
'bbox'
]
bottom_top
=
subjects
[
dis_by_directions
[
'top'
][
i
][
0
]][
'bbox'
]
top_bottom_x_axis
=
top_bottom
[
2
]
-
top_bottom
[
0
]
bottom_top_x_axis
=
bottom_top
[
2
]
-
bottom_top
[
0
]
if
(
abs
(
top_bottom_x_axis
-
l_x_axis
)
+
dis_by_directions
[
'bottom'
][
i
][
1
]
>
abs
(
bottom_top_x_axis
-
l_x_axis
)
+
dis_by_directions
[
'top'
][
i
][
1
]
):
top_or_bottom
=
dis_by_directions
[
'top'
][
i
]
else
:
top_or_bottom
=
dis_by_directions
[
'bottom'
][
i
]
else
:
top_or_bottom
=
dis_by_directions
[
'top'
][
i
]
if
top_or_bottom
[
1
]
>
dis_by_directions
[
'bottom'
][
i
][
1
]:
top_or_bottom
=
dis_by_directions
[
'bottom'
][
i
]
else
:
top_or_bottom
=
dis_by_directions
[
'top'
][
i
]
if
top_or_bottom
[
1
]
==
float
(
'inf'
):
top_or_bottom
=
dis_by_directions
[
'bottom'
][
i
]
else
:
top_or_bottom
=
[
-
1
,
float
(
'inf'
)]
if
left_or_right
[
1
]
!=
float
(
'inf'
)
or
top_or_bottom
[
1
]
!=
float
(
'inf'
):
if
left_or_right
[
1
]
!=
float
(
'inf'
)
and
top_or_bottom
[
1
]
!=
float
(
'inf'
):
if
AXIS_MULPLICITY
*
axis_unit
>=
abs
(
left_or_right
[
1
]
-
top_or_bottom
[
1
]
):
y_axis_bbox
=
subjects
[
left_or_right
[
0
]][
'bbox'
]
x_axis_bbox
=
subjects
[
top_or_bottom
[
0
]][
'bbox'
]
if
(
abs
((
x_axis_bbox
[
2
]
-
x_axis_bbox
[
0
])
-
l_x_axis
)
/
l_x_axis
>
abs
((
y_axis_bbox
[
3
]
-
y_axis_bbox
[
1
])
-
l_y_axis
)
/
l_y_axis
):
sub_obj_map_h
[
left_or_right
[
0
]].
append
(
i
)
else
:
sub_obj_map_h
[
top_or_bottom
[
0
]].
append
(
i
)
else
:
if
left_or_right
[
1
]
>
top_or_bottom
[
1
]:
sub_obj_map_h
[
top_or_bottom
[
0
]].
append
(
i
)
else
:
sub_obj_map_h
[
left_or_right
[
0
]].
append
(
i
)
else
:
if
left_or_right
[
1
]
!=
float
(
'inf'
):
sub_obj_map_h
[
left_or_right
[
0
]].
append
(
i
)
else
:
sub_obj_map_h
[
top_or_bottom
[
0
]].
append
(
i
)
ret
=
[]
for
i
in
sub_obj_map_h
.
keys
():
ret
.
append
(
{
'sub_bbox'
:
{
'bbox'
:
subjects
[
i
][
'bbox'
],
'score'
:
subjects
[
i
][
'score'
],
},
'obj_bboxes'
:
[
{
'score'
:
objects
[
j
][
'score'
],
'bbox'
:
objects
[
j
][
'bbox'
]}
for
j
in
sub_obj_map_h
[
i
]
],
'sub_idx'
:
i
,
}
)
return
ret
def
__tie_up_category_by_distance_v3
(
def
__tie_up_category_by_distance_v3
(
self
,
self
,
page_no
:
int
,
subject_category_id
:
int
,
subject_category_id
:
int
,
object_category_id
:
int
,
object_category_id
:
int
,
priority_pos
:
PosRelationEnum
,
):
):
subjects
=
self
.
__reduct_overlap
(
subjects
=
self
.
__reduct_overlap
(
list
(
list
(
...
@@ -464,7 +171,7 @@ class MagicModel:
...
@@ -464,7 +171,7 @@ class MagicModel:
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
filter
(
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
self
.
__model_
list
[
page_no
]
[
'layout_dets'
],
self
.
__
page_
model_
info
[
'layout_dets'
],
),
),
)
)
)
)
...
@@ -475,7 +182,7 @@ class MagicModel:
...
@@ -475,7 +182,7 @@ class MagicModel:
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
filter
(
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
self
.
__model_
list
[
page_no
]
[
'layout_dets'
],
self
.
__
page_
model_
info
[
'layout_dets'
],
),
),
)
)
)
)
...
@@ -605,13 +312,12 @@ class MagicModel:
...
@@ -605,13 +312,12 @@ class MagicModel:
return
ret
return
ret
def
get_imgs
(
self
):
def
get_imgs_v2
(
self
,
page_no
:
int
):
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
page_no
,
3
,
4
,
PosRelationEnum
.
BOTTOM
3
,
4
)
)
with_footnotes
=
self
.
__tie_up_category_by_distance_v3
(
with_footnotes
=
self
.
__tie_up_category_by_distance_v3
(
page_no
,
3
,
CategoryId
.
ImageFootnote
,
PosRelationEnum
.
ALL
3
,
CategoryId
.
ImageFootnote
)
)
ret
=
[]
ret
=
[]
for
v
in
with_captions
:
for
v
in
with_captions
:
...
@@ -625,12 +331,12 @@ class MagicModel:
...
@@ -625,12 +331,12 @@ class MagicModel:
ret
.
append
(
record
)
ret
.
append
(
record
)
return
ret
return
ret
def
get_tables
_v2
(
self
,
page_no
:
int
)
->
list
:
def
get_tables
(
self
)
->
list
:
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
with_captions
=
self
.
__tie_up_category_by_distance_v3
(
page_no
,
5
,
6
,
PosRelationEnum
.
UP
5
,
6
)
)
with_footnotes
=
self
.
__tie_up_category_by_distance_v3
(
with_footnotes
=
self
.
__tie_up_category_by_distance_v3
(
page_no
,
5
,
7
,
PosRelationEnum
.
ALL
5
,
7
)
)
ret
=
[]
ret
=
[]
for
v
in
with_captions
:
for
v
in
with_captions
:
...
@@ -644,52 +350,31 @@ class MagicModel:
...
@@ -644,52 +350,31 @@ class MagicModel:
ret
.
append
(
record
)
ret
.
append
(
record
)
return
ret
return
ret
def
get_imgs
(
self
,
page_no
:
int
):
def
get_equations
(
self
)
->
tuple
[
list
,
list
,
list
]:
# 有坐标,也有字
return
self
.
get_imgs_v2
(
page_no
)
def
get_tables
(
self
,
page_no
:
int
)
->
list
:
# 3个坐标, caption, table主体,table-note
return
self
.
get_tables_v2
(
page_no
)
def
get_equations
(
self
,
page_no
:
int
)
->
list
:
# 有坐标,也有字
inline_equations
=
self
.
__get_blocks_by_type
(
inline_equations
=
self
.
__get_blocks_by_type
(
ModelBlockTypeEnum
.
EMBEDDING
.
value
,
page_no
,
[
'latex'
]
CategoryId
.
InlineEquation
,
[
'latex'
]
)
)
interline_equations
=
self
.
__get_blocks_by_type
(
interline_equations
=
self
.
__get_blocks_by_type
(
ModelBlockTypeEnum
.
ISOLATED
.
value
,
page_no
,
[
'latex'
]
CategoryId
.
InterlineEquation_YOLO
,
[
'latex'
]
)
)
interline_equations_blocks
=
self
.
__get_blocks_by_type
(
interline_equations_blocks
=
self
.
__get_blocks_by_type
(
ModelBlockTypeEnum
.
ISOLATE_FORMULA
.
value
,
page_no
CategoryId
.
InterlineEquation_Layout
)
)
return
inline_equations
,
interline_equations
,
interline_equations_blocks
return
inline_equations
,
interline_equations
,
interline_equations_blocks
def
get_discarded
(
self
,
page_no
:
int
)
->
list
:
# 自研模型,只有坐标
def
get_discarded
(
self
)
->
list
:
# 自研模型,只有坐标
blocks
=
self
.
__get_blocks_by_type
(
ModelBlockTypeEnum
.
ABANDON
.
value
,
page_no
)
blocks
=
self
.
__get_blocks_by_type
(
CategoryId
.
Abandon
)
return
blocks
return
blocks
def
get_text_blocks
(
self
,
page_no
:
int
)
->
list
:
# 自研模型搞的,只有坐标,没有字
def
get_text_blocks
(
self
)
->
list
:
# 自研模型搞的,只有坐标,没有字
blocks
=
self
.
__get_blocks_by_type
(
ModelBlockTypeEnum
.
PLAIN_TEXT
.
value
,
page_no
)
blocks
=
self
.
__get_blocks_by_type
(
CategoryId
.
Text
)
return
blocks
return
blocks
def
get_title_blocks
(
self
,
page_no
:
int
)
->
list
:
# 自研模型,只有坐标,没字
def
get_title_blocks
(
self
)
->
list
:
# 自研模型,只有坐标,没字
blocks
=
self
.
__get_blocks_by_type
(
ModelBlockTypeEnum
.
TITLE
.
value
,
page_no
)
blocks
=
self
.
__get_blocks_by_type
(
CategoryId
.
Title
)
return
blocks
return
blocks
def
get_ocr_text
(
self
,
page_no
:
int
)
->
list
:
# paddle 搞的,有字也有坐标
def
get_all_spans
(
self
)
->
list
:
text_spans
=
[]
model_page_info
=
self
.
__model_list
[
page_no
]
layout_dets
=
model_page_info
[
'layout_dets'
]
for
layout_det
in
layout_dets
:
if
layout_det
[
'category_id'
]
==
'15'
:
span
=
{
'bbox'
:
layout_det
[
'bbox'
],
'content'
:
layout_det
[
'text'
],
}
text_spans
.
append
(
span
)
return
text_spans
def
get_all_spans
(
self
,
page_no
:
int
)
->
list
:
def
remove_duplicate_spans
(
spans
):
def
remove_duplicate_spans
(
spans
):
new_spans
=
[]
new_spans
=
[]
...
@@ -699,8 +384,7 @@ class MagicModel:
...
@@ -699,8 +384,7 @@ class MagicModel:
return
new_spans
return
new_spans
all_spans
=
[]
all_spans
=
[]
model_page_info
=
self
.
__model_list
[
page_no
]
layout_dets
=
self
.
__page_model_info
[
'layout_dets'
]
layout_dets
=
model_page_info
[
'layout_dets'
]
allow_category_id_list
=
[
3
,
5
,
13
,
14
,
15
]
allow_category_id_list
=
[
3
,
5
,
13
,
14
,
15
]
"""当成span拼接的"""
"""当成span拼接的"""
# 3: 'image', # 图片
# 3: 'image', # 图片
...
@@ -713,7 +397,7 @@ class MagicModel:
...
@@ -713,7 +397,7 @@ class MagicModel:
if
category_id
in
allow_category_id_list
:
if
category_id
in
allow_category_id_list
:
span
=
{
'bbox'
:
layout_det
[
'bbox'
],
'score'
:
layout_det
[
'score'
]}
span
=
{
'bbox'
:
layout_det
[
'bbox'
],
'score'
:
layout_det
[
'score'
]}
if
category_id
==
3
:
if
category_id
==
3
:
span
[
'type'
]
=
ContentType
.
I
mage
span
[
'type'
]
=
ContentType
.
I
MAGE
elif
category_id
==
5
:
elif
category_id
==
5
:
# 获取table模型结果
# 获取table模型结果
latex
=
layout_det
.
get
(
'latex'
,
None
)
latex
=
layout_det
.
get
(
'latex'
,
None
)
...
@@ -722,50 +406,36 @@ class MagicModel:
...
@@ -722,50 +406,36 @@ class MagicModel:
span
[
'latex'
]
=
latex
span
[
'latex'
]
=
latex
elif
html
:
elif
html
:
span
[
'html'
]
=
html
span
[
'html'
]
=
html
span
[
'type'
]
=
ContentType
.
T
able
span
[
'type'
]
=
ContentType
.
T
ABLE
elif
category_id
==
13
:
elif
category_id
==
13
:
span
[
'content'
]
=
layout_det
[
'latex'
]
span
[
'content'
]
=
layout_det
[
'latex'
]
span
[
'type'
]
=
ContentType
.
I
nlineEquation
span
[
'type'
]
=
ContentType
.
I
NLINE_EQUATION
elif
category_id
==
14
:
elif
category_id
==
14
:
span
[
'content'
]
=
layout_det
[
'latex'
]
span
[
'content'
]
=
layout_det
[
'latex'
]
span
[
'type'
]
=
ContentType
.
I
nterlineEquation
span
[
'type'
]
=
ContentType
.
I
NTERLINE_EQUATION
elif
category_id
==
15
:
elif
category_id
==
15
:
span
[
'content'
]
=
layout_det
[
'text'
]
span
[
'content'
]
=
layout_det
[
'text'
]
span
[
'type'
]
=
ContentType
.
T
ext
span
[
'type'
]
=
ContentType
.
T
EXT
all_spans
.
append
(
span
)
all_spans
.
append
(
span
)
return
remove_duplicate_spans
(
all_spans
)
return
remove_duplicate_spans
(
all_spans
)
def
get_page_size
(
self
,
page_no
:
int
):
# 获取页面宽高
# 获取当前页的page对象
page
=
self
.
__docs
.
get_page
(
page_no
).
get_page_info
()
# 获取当前页的宽高
page_w
=
page
.
w
page_h
=
page
.
h
return
page_w
,
page_h
def
__get_blocks_by_type
(
def
__get_blocks_by_type
(
self
,
type
:
int
,
page_no
:
int
,
extra_col
:
list
[
str
]
=
[]
self
,
category_type
:
int
,
extra_col
=
None
)
->
list
:
)
->
list
:
if
extra_col
is
None
:
extra_col
=
[]
blocks
=
[]
blocks
=
[]
for
page_dict
in
self
.
__model_list
:
layout_dets
=
self
.
__page_model_info
.
get
(
'layout_dets'
,
[])
layout_dets
=
page_dict
.
get
(
'layout_dets'
,
[])
for
item
in
layout_dets
:
page_info
=
page_dict
.
get
(
'page_info'
,
{})
category_id
=
item
.
get
(
'category_id'
,
-
1
)
page_number
=
page_info
.
get
(
'page_no'
,
-
1
)
bbox
=
item
.
get
(
'bbox'
,
None
)
if
page_no
!=
page_number
:
continue
if
category_id
==
category_type
:
for
item
in
layout_dets
:
block
=
{
category_id
=
item
.
get
(
'category_id'
,
-
1
)
'bbox'
:
bbox
,
bbox
=
item
.
get
(
'bbox'
,
None
)
'score'
:
item
.
get
(
'score'
),
}
if
category_id
==
type
:
for
col
in
extra_col
:
block
=
{
block
[
col
]
=
item
.
get
(
col
,
None
)
'bbox'
:
bbox
,
blocks
.
append
(
block
)
'score'
:
item
.
get
(
'score'
),
}
for
col
in
extra_col
:
block
[
col
]
=
item
.
get
(
col
,
None
)
blocks
.
append
(
block
)
return
blocks
return
blocks
def
get_model_list
(
self
,
page_no
):
return
self
.
__model_list
[
page_no
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment