Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
ea2f8ea0
Unverified
Commit
ea2f8ea0
authored
Nov 21, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Nov 21, 2024
Browse files
Merge branch 'dev' into dev
parents
e4810cee
23c8436e
Changes
59
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
637 additions
and
460 deletions
+637
-460
magic_pdf/config/constants.py
magic_pdf/config/constants.py
+53
-0
magic_pdf/config/drop_reason.py
magic_pdf/config/drop_reason.py
+35
-0
magic_pdf/config/drop_tag.py
magic_pdf/config/drop_tag.py
+19
-0
magic_pdf/config/make_content_config.py
magic_pdf/config/make_content_config.py
+11
-0
magic_pdf/config/model_block_type.py
magic_pdf/config/model_block_type.py
+2
-1
magic_pdf/config/ocr_content_type.py
magic_pdf/config/ocr_content_type.py
+0
-0
magic_pdf/data/read_api.py
magic_pdf/data/read_api.py
+1
-1
magic_pdf/dict2md/mkcontent.py
magic_pdf/dict2md/mkcontent.py
+226
-185
magic_pdf/dict2md/ocr_mkcontent.py
magic_pdf/dict2md/ocr_mkcontent.py
+7
-8
magic_pdf/filter/pdf_meta_scan.py
magic_pdf/filter/pdf_meta_scan.py
+101
-79
magic_pdf/integrations/rag/utils.py
magic_pdf/integrations/rag/utils.py
+4
-5
magic_pdf/libs/MakeContentConfig.py
magic_pdf/libs/MakeContentConfig.py
+0
-11
magic_pdf/libs/config_reader.py
magic_pdf/libs/config_reader.py
+5
-5
magic_pdf/libs/draw_bbox.py
magic_pdf/libs/draw_bbox.py
+3
-2
magic_pdf/libs/drop_reason.py
magic_pdf/libs/drop_reason.py
+0
-27
magic_pdf/libs/drop_tag.py
magic_pdf/libs/drop_tag.py
+0
-19
magic_pdf/libs/pdf_image_tools.py
magic_pdf/libs/pdf_image_tools.py
+9
-11
magic_pdf/model/magic_model.py
magic_pdf/model/magic_model.py
+13
-13
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+109
-59
magic_pdf/model/sub_modules/model_init.py
magic_pdf/model/sub_modules/model_init.py
+39
-34
No files found.
magic_pdf/
libs/C
onstants.py
→
magic_pdf/
config/c
onstants.py
View file @
ea2f8ea0
"""
"""span维度自定义字段."""
span维度自定义字段
"""
# span是否是跨页合并的
# span是否是跨页合并的
CROSS_PAGE
=
"
cross_page
"
CROSS_PAGE
=
'
cross_page
'
"""
"""
block维度自定义字段
block维度自定义字段
"""
"""
# block中lines是否被删除
# block中lines是否被删除
LINES_DELETED
=
"
lines_deleted
"
LINES_DELETED
=
'
lines_deleted
'
# table recognition max time default value
# table recognition max time default value
TABLE_MAX_TIME_VALUE
=
400
TABLE_MAX_TIME_VALUE
=
400
...
@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400
...
@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400
TABLE_MAX_LEN
=
480
TABLE_MAX_LEN
=
480
# table master structure dict
# table master structure dict
TABLE_MASTER_DICT
=
"
table_master_structure_dict.txt
"
TABLE_MASTER_DICT
=
'
table_master_structure_dict.txt
'
# table master dir
# table master dir
TABLE_MASTER_DIR
=
"
table_structure_tablemaster_infer/
"
TABLE_MASTER_DIR
=
'
table_structure_tablemaster_infer/
'
# pp detect model dir
# pp detect model dir
DETECT_MODEL_DIR
=
"
ch_PP-OCRv4_det_infer
"
DETECT_MODEL_DIR
=
'
ch_PP-OCRv4_det_infer
'
# pp rec model dir
# pp rec model dir
REC_MODEL_DIR
=
"
ch_PP-OCRv4_rec_infer
"
REC_MODEL_DIR
=
'
ch_PP-OCRv4_rec_infer
'
# pp rec char dict path
# pp rec char dict path
REC_CHAR_DICT
=
"
ppocr_keys_v1.txt
"
REC_CHAR_DICT
=
'
ppocr_keys_v1.txt
'
# pp rec copy rec directory
# pp rec copy rec directory
PP_REC_DIRECTORY
=
"
.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer
"
PP_REC_DIRECTORY
=
'
.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer
'
# pp rec copy det directory
# pp rec copy det directory
PP_DET_DIRECTORY
=
"
.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer
"
PP_DET_DIRECTORY
=
'
.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer
'
class
MODEL_NAME
:
class
MODEL_NAME
:
# pp table structure algorithm
# pp table structure algorithm
TABLE_MASTER
=
"
tablemaster
"
TABLE_MASTER
=
'
tablemaster
'
# struct eqtable
# struct eqtable
STRUCT_EQTABLE
=
"
struct_eqtable
"
STRUCT_EQTABLE
=
'
struct_eqtable
'
DocLayout_YOLO
=
"
doclayout_yolo
"
DocLayout_YOLO
=
'
doclayout_yolo
'
LAYOUTLMv3
=
"
layoutlmv3
"
LAYOUTLMv3
=
'
layoutlmv3
'
YOLO_V8_MFD
=
"
yolo_v8_mfd
"
YOLO_V8_MFD
=
'
yolo_v8_mfd
'
UniMerNet_v2_Small
=
"
unimernet_small
"
UniMerNet_v2_Small
=
'
unimernet_small
'
RAPID_TABLE
=
"rapid_table"
RAPID_TABLE
=
'rapid_table'
\ No newline at end of file
magic_pdf/config/drop_reason.py
0 → 100644
View file @
ea2f8ea0
class
DropReason
:
TEXT_BLCOK_HOR_OVERLAP
=
'text_block_horizontal_overlap'
# 文字块有水平互相覆盖,导致无法准确定位文字顺序
USEFUL_BLOCK_HOR_OVERLAP
=
(
'useful_block_horizontal_overlap'
# 需保留的block水平覆盖
)
COMPLICATED_LAYOUT
=
'complicated_layout'
# 复杂的布局,暂时不支持
TOO_MANY_LAYOUT_COLUMNS
=
'too_many_layout_columns'
# 目前不支持分栏超过2列的
COLOR_BACKGROUND_TEXT_BOX
=
'color_background_text_box'
# 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
HIGH_COMPUTATIONAL_lOAD_BY_IMGS
=
(
'high_computational_load_by_imgs'
# 含特殊图片,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_SVGS
=
(
'high_computational_load_by_svgs'
# 特殊的SVG图,计算量太大,从而丢弃
)
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
=
'high_computational_load_by_total_pages'
# 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT
=
'missing doc_layout_result'
# 版面分析失败
Exception
=
'_exception'
# 解析中发生异常
ENCRYPTED
=
'encrypted'
# PDF是加密的
EMPTY_PDF
=
'total_page=0'
# PDF页面总数为0
NOT_IS_TEXT_PDF
=
'not_is_text_pdf'
# 不是文字版PDF,无法直接解析
DENSE_SINGLE_LINE_BLOCK
=
'dense_single_line_block'
# 无法清晰的分段
TITLE_DETECTION_FAILED
=
'title_detection_failed'
# 探测标题失败
TITLE_LEVEL_FAILED
=
(
'title_level_failed'
# 分析标题级别失败(例如一级、二级、三级标题)
)
PARA_SPLIT_FAILED
=
'para_split_failed'
# 识别段落失败
PARA_MERGE_FAILED
=
'para_merge_failed'
# 段落合并失败
NOT_ALLOW_LANGUAGE
=
'not_allow_language'
# 不支持的语种
SPECIAL_PDF
=
'special_pdf'
PSEUDO_SINGLE_COLUMN
=
'pseudo_single_column'
# 无法精确判断文字分栏
CAN_NOT_DETECT_PAGE_LAYOUT
=
'can_not_detect_page_layout'
# 无法分析页面的版面
NEGATIVE_BBOX_AREA
=
'negative_bbox_area'
# 缩放导致 bbox 面积为负
OVERLAP_BLOCKS_CAN_NOT_SEPARATION
=
(
'overlap_blocks_can_t_separation'
# 无法分离重叠的block
)
magic_pdf/config/drop_tag.py
0 → 100644
View file @
ea2f8ea0
COLOR_BG_HEADER_TXT_BLOCK
=
'color_background_header_txt_block'
PAGE_NO
=
'page-no'
# 页码
CONTENT_IN_FOOT_OR_HEADER
=
'in-foot-header-area'
# 页眉页脚内的文本
VERTICAL_TEXT
=
'vertical-text'
# 垂直文本
ROTATE_TEXT
=
'rotate-text'
# 旋转文本
EMPTY_SIDE_BLOCK
=
'empty-side-block'
# 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT
=
'on-image-text'
# 文本在图片上
ON_TABLE_TEXT
=
'on-table-text'
# 文本在表格上
class
DropTag
:
PAGE_NUMBER
=
'page_no'
HEADER
=
'header'
FOOTER
=
'footer'
FOOTNOTE
=
'footnote'
NOT_IN_LAYOUT
=
'not_in_layout'
SPAN_OVERLAP
=
'span_overlap'
BLOCK_OVERLAP
=
'block_overlap'
magic_pdf/config/make_content_config.py
0 → 100644
View file @
ea2f8ea0
class
MakeMode
:
MM_MD
=
'mm_markdown'
NLP_MD
=
'nlp_markdown'
STANDARD_FORMAT
=
'standard_format'
class
DropMode
:
WHOLE_PDF
=
'whole_pdf'
SINGLE_PAGE
=
'single_page'
NONE
=
'none'
NONE_WITH_REASON
=
'none_with_reason'
magic_pdf/
libs/M
odel
B
lock
T
ype
Enum
.py
→
magic_pdf/
config/m
odel
_b
lock
_t
ype.py
View file @
ea2f8ea0
from
enum
import
Enum
from
enum
import
Enum
class
ModelBlockTypeEnum
(
Enum
):
class
ModelBlockTypeEnum
(
Enum
):
TITLE
=
0
TITLE
=
0
PLAIN_TEXT
=
1
PLAIN_TEXT
=
1
ABANDON
=
2
ABANDON
=
2
ISOLATE_FORMULA
=
8
ISOLATE_FORMULA
=
8
EMBEDDING
=
13
EMBEDDING
=
13
ISOLATED
=
14
ISOLATED
=
14
\ No newline at end of file
magic_pdf/
libs
/ocr_content_type.py
→
magic_pdf/
config
/ocr_content_type.py
View file @
ea2f8ea0
File moved
magic_pdf/data/read_api.py
View file @
ea2f8ea0
...
@@ -35,7 +35,7 @@ def read_jsonl(
...
@@ -35,7 +35,7 @@ def read_jsonl(
jsonl_d
=
[
jsonl_d
=
[
json
.
loads
(
line
)
for
line
in
jsonl_bits
.
decode
().
split
(
'
\n
'
)
if
line
.
strip
()
json
.
loads
(
line
)
for
line
in
jsonl_bits
.
decode
().
split
(
'
\n
'
)
if
line
.
strip
()
]
]
for
d
in
jsonl_d
[:
5
]
:
for
d
in
jsonl_d
:
pdf_path
=
d
.
get
(
'file_location'
,
''
)
or
d
.
get
(
'path'
,
''
)
pdf_path
=
d
.
get
(
'file_location'
,
''
)
or
d
.
get
(
'path'
,
''
)
if
len
(
pdf_path
)
==
0
:
if
len
(
pdf_path
)
==
0
:
raise
EmptyData
(
'pdf file location is empty'
)
raise
EmptyData
(
'pdf file location is empty'
)
...
...
magic_pdf/dict2md/mkcontent.py
View file @
ea2f8ea0
import
math
import
math
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.libs.boxbase
import
find_bottom_nearest_text_bbox
,
find_top_nearest_text_bbox
from
magic_pdf.config.ocr_content_type
import
ContentType
from
magic_pdf.libs.boxbase
import
(
find_bottom_nearest_text_bbox
,
find_top_nearest_text_bbox
)
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.libs.ocr_content_type
import
ContentType
TYPE_INLINE_EQUATION
=
ContentType
.
InlineEquation
TYPE_INLINE_EQUATION
=
ContentType
.
InlineEquation
TYPE_INTERLINE_EQUATION
=
ContentType
.
InterlineEquation
TYPE_INTERLINE_EQUATION
=
ContentType
.
InterlineEquation
...
@@ -12,33 +14,30 @@ UNI_FORMAT_TEXT_TYPE = ['text', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']
...
@@ -12,33 +14,30 @@ UNI_FORMAT_TEXT_TYPE = ['text', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6']
@
DeprecationWarning
@
DeprecationWarning
def
mk_nlp_markdown_1
(
para_dict
:
dict
):
def
mk_nlp_markdown_1
(
para_dict
:
dict
):
"""
"""对排序后的bboxes拼接内容."""
对排序后的bboxes拼接内容
"""
content_lst
=
[]
content_lst
=
[]
for
_
,
page_info
in
para_dict
.
items
():
for
_
,
page_info
in
para_dict
.
items
():
para_blocks
=
page_info
.
get
(
"
para_blocks
"
)
para_blocks
=
page_info
.
get
(
'
para_blocks
'
)
if
not
para_blocks
:
if
not
para_blocks
:
continue
continue
for
block
in
para_blocks
:
for
block
in
para_blocks
:
item
=
block
[
"
paras
"
]
item
=
block
[
'
paras
'
]
for
_
,
p
in
item
.
items
():
for
_
,
p
in
item
.
items
():
para_text
=
p
[
"
para_text
"
]
para_text
=
p
[
'
para_text
'
]
is_title
=
p
[
"
is_para_title
"
]
is_title
=
p
[
'
is_para_title
'
]
title_level
=
p
[
'para_title_level'
]
title_level
=
p
[
'para_title_level'
]
md_title_prefix
=
"#"
*
title_level
md_title_prefix
=
'#'
*
title_level
if
is_title
:
if
is_title
:
content_lst
.
append
(
f
"
{
md_title_prefix
}
{
para_text
}
"
)
content_lst
.
append
(
f
'
{
md_title_prefix
}
{
para_text
}
'
)
else
:
else
:
content_lst
.
append
(
para_text
)
content_lst
.
append
(
para_text
)
content_text
=
"
\n\n
"
.
join
(
content_lst
)
content_text
=
'
\n\n
'
.
join
(
content_lst
)
return
content_text
return
content_text
# 找到目标字符串在段落中的索引
# 找到目标字符串在段落中的索引
def
__find_index
(
paragraph
,
target
):
def
__find_index
(
paragraph
,
target
):
index
=
paragraph
.
find
(
target
)
index
=
paragraph
.
find
(
target
)
...
@@ -48,69 +47,76 @@ def __find_index(paragraph, target):
...
@@ -48,69 +47,76 @@ def __find_index(paragraph, target):
return
None
return
None
def
__insert_string
(
paragraph
,
target
,
postion
):
def
__insert_string
(
paragraph
,
target
,
pos
i
tion
):
new_paragraph
=
paragraph
[:
postion
]
+
target
+
paragraph
[
postion
:]
new_paragraph
=
paragraph
[:
pos
i
tion
]
+
target
+
paragraph
[
pos
i
tion
:]
return
new_paragraph
return
new_paragraph
def
__insert_after
(
content
,
image_content
,
target
):
def
__insert_after
(
content
,
image_content
,
target
):
"""
"""在content中找到target,将image_content插入到target后面."""
在content中找到target,将image_content插入到target后面
"""
index
=
content
.
find
(
target
)
index
=
content
.
find
(
target
)
if
index
!=
-
1
:
if
index
!=
-
1
:
content
=
content
[:
index
+
len
(
target
)]
+
"
\n\n
"
+
image_content
+
"
\n\n
"
+
content
[
index
+
len
(
target
):]
content
=
(
content
[:
index
+
len
(
target
)]
+
'
\n\n
'
+
image_content
+
'
\n\n
'
+
content
[
index
+
len
(
target
)
:]
)
else
:
else
:
logger
.
error
(
f
"Can't find the location of image
{
image_content
}
in the markdown file, search target is
{
target
}
"
)
logger
.
error
(
f
"Can't find the location of image
{
image_content
}
in the markdown file, search target is
{
target
}
"
)
return
content
return
content
def
__insert_before
(
content
,
image_content
,
target
):
def
__insert_before
(
content
,
image_content
,
target
):
"""
"""在content中找到target,将image_content插入到target前面."""
在content中找到target,将image_content插入到target前面
"""
index
=
content
.
find
(
target
)
index
=
content
.
find
(
target
)
if
index
!=
-
1
:
if
index
!=
-
1
:
content
=
content
[:
index
]
+
"
\n\n
"
+
image_content
+
"
\n\n
"
+
content
[
index
:]
content
=
content
[:
index
]
+
'
\n\n
'
+
image_content
+
'
\n\n
'
+
content
[
index
:]
else
:
else
:
logger
.
error
(
f
"Can't find the location of image
{
image_content
}
in the markdown file, search target is
{
target
}
"
)
logger
.
error
(
f
"Can't find the location of image
{
image_content
}
in the markdown file, search target is
{
target
}
"
)
return
content
return
content
@
DeprecationWarning
@
DeprecationWarning
def
mk_mm_markdown_1
(
para_dict
:
dict
):
def
mk_mm_markdown_1
(
para_dict
:
dict
):
"""拼装多模态markdown"""
"""拼装多模态markdown
.
"""
content_lst
=
[]
content_lst
=
[]
for
_
,
page_info
in
para_dict
.
items
():
for
_
,
page_info
in
para_dict
.
items
():
page_lst
=
[]
# 一个page内的段落列表
page_lst
=
[]
# 一个page内的段落列表
para_blocks
=
page_info
.
get
(
"
para_blocks
"
)
para_blocks
=
page_info
.
get
(
'
para_blocks
'
)
pymu_raw_blocks
=
page_info
.
get
(
"
preproc_blocks
"
)
pymu_raw_blocks
=
page_info
.
get
(
'
preproc_blocks
'
)
all_page_images
=
[]
all_page_images
=
[]
all_page_images
.
extend
(
page_info
.
get
(
"
images
"
,
[]))
all_page_images
.
extend
(
page_info
.
get
(
'
images
'
,
[]))
all_page_images
.
extend
(
page_info
.
get
(
"
image_backup
"
,
[])
)
all_page_images
.
extend
(
page_info
.
get
(
'
image_backup
'
,
[]))
all_page_images
.
extend
(
page_info
.
get
(
"
tables
"
,
[]))
all_page_images
.
extend
(
page_info
.
get
(
'
tables
'
,
[]))
all_page_images
.
extend
(
page_info
.
get
(
"
table_backup
"
,
[])
)
all_page_images
.
extend
(
page_info
.
get
(
'
table_backup
'
,
[]))
if
not
para_blocks
or
not
pymu_raw_blocks
:
# 只有图片的拼接的场景
if
not
para_blocks
or
not
pymu_raw_blocks
:
# 只有图片的拼接的场景
for
img
in
all_page_images
:
for
img
in
all_page_images
:
page_lst
.
append
(
f
""
)
# TODO 图片顺序
page_lst
.
append
(
f
""
)
# TODO 图片顺序
page_md
=
"
\n\n
"
.
join
(
page_lst
)
page_md
=
'
\n\n
'
.
join
(
page_lst
)
else
:
else
:
for
block
in
para_blocks
:
for
block
in
para_blocks
:
item
=
block
[
"
paras
"
]
item
=
block
[
'
paras
'
]
for
_
,
p
in
item
.
items
():
for
_
,
p
in
item
.
items
():
para_text
=
p
[
"
para_text
"
]
para_text
=
p
[
'
para_text
'
]
is_title
=
p
[
"
is_para_title
"
]
is_title
=
p
[
'
is_para_title
'
]
title_level
=
p
[
'para_title_level'
]
title_level
=
p
[
'para_title_level'
]
md_title_prefix
=
"#"
*
title_level
md_title_prefix
=
'#'
*
title_level
if
is_title
:
if
is_title
:
page_lst
.
append
(
f
"
{
md_title_prefix
}
{
para_text
}
"
)
page_lst
.
append
(
f
'
{
md_title_prefix
}
{
para_text
}
'
)
else
:
else
:
page_lst
.
append
(
para_text
)
page_lst
.
append
(
para_text
)
"""拼装成一个页面的文本"""
"""拼装成一个页面的文本"""
page_md
=
"
\n\n
"
.
join
(
page_lst
)
page_md
=
'
\n\n
'
.
join
(
page_lst
)
"""插入图片"""
"""插入图片"""
for
img
in
all_page_images
:
for
img
in
all_page_images
:
imgbox
=
img
[
'bbox'
]
imgbox
=
img
[
'bbox'
]
...
@@ -118,192 +124,215 @@ def mk_mm_markdown_1(para_dict: dict):
...
@@ -118,192 +124,215 @@ def mk_mm_markdown_1(para_dict: dict):
# 先看在哪个block内
# 先看在哪个block内
for
block
in
pymu_raw_blocks
:
for
block
in
pymu_raw_blocks
:
bbox
=
block
[
'bbox'
]
bbox
=
block
[
'bbox'
]
if
bbox
[
0
]
-
1
<=
imgbox
[
0
]
<
bbox
[
2
]
+
1
and
bbox
[
1
]
-
1
<=
imgbox
[
1
]
<
bbox
[
3
]
+
1
:
# 确定在block内
if
(
for
l
in
block
[
'lines'
]:
bbox
[
0
]
-
1
<=
imgbox
[
0
]
<
bbox
[
2
]
+
1
and
bbox
[
1
]
-
1
<=
imgbox
[
1
]
<
bbox
[
3
]
+
1
):
# 确定在block内
for
l
in
block
[
'lines'
]:
# noqa: E741
line_box
=
l
[
'bbox'
]
line_box
=
l
[
'bbox'
]
if
line_box
[
0
]
-
1
<=
imgbox
[
0
]
<
line_box
[
2
]
+
1
and
line_box
[
1
]
-
1
<=
imgbox
[
1
]
<
line_box
[
3
]
+
1
:
# 在line内的,插入line前面
if
(
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
l
[
'spans'
]])
line_box
[
0
]
-
1
<=
imgbox
[
0
]
<
line_box
[
2
]
+
1
page_md
=
__insert_before
(
page_md
,
img_content
,
line_txt
)
and
line_box
[
1
]
-
1
<=
imgbox
[
1
]
<
line_box
[
3
]
+
1
):
# 在line内的,插入line前面
line_txt
=
''
.
join
([
s
[
'text'
]
for
s
in
l
[
'spans'
]])
page_md
=
__insert_before
(
page_md
,
img_content
,
line_txt
)
break
break
break
break
else
:
# 在行与行之间
else
:
# 在行与行之间
# 找到图片x0,y0与line的x0,y0最近的line
# 找到图片x0,y0与line的x0,y0最近的line
min_distance
=
100000
min_distance
=
100000
min_line
=
None
min_line
=
None
for
l
in
block
[
'lines'
]:
for
l
in
block
[
'lines'
]:
# noqa: E741
line_box
=
l
[
'bbox'
]
line_box
=
l
[
'bbox'
]
distance
=
math
.
sqrt
((
line_box
[
0
]
-
imgbox
[
0
])
**
2
+
(
line_box
[
1
]
-
imgbox
[
1
])
**
2
)
distance
=
math
.
sqrt
(
(
line_box
[
0
]
-
imgbox
[
0
])
**
2
+
(
line_box
[
1
]
-
imgbox
[
1
])
**
2
)
if
distance
<
min_distance
:
if
distance
<
min_distance
:
min_distance
=
distance
min_distance
=
distance
min_line
=
l
min_line
=
l
if
min_line
:
if
min_line
:
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
min_line
[
'spans'
]])
line_txt
=
''
.
join
(
[
s
[
'text'
]
for
s
in
min_line
[
'spans'
]]
)
img_h
=
imgbox
[
3
]
-
imgbox
[
1
]
img_h
=
imgbox
[
3
]
-
imgbox
[
1
]
if
min_distance
<
img_h
:
# 文字在图片前面
if
min_distance
<
img_h
:
# 文字在图片前面
page_md
=
__insert_after
(
page_md
,
img_content
,
line_txt
)
page_md
=
__insert_after
(
page_md
,
img_content
,
line_txt
)
else
:
else
:
page_md
=
__insert_before
(
page_md
,
img_content
,
line_txt
)
page_md
=
__insert_before
(
page_md
,
img_content
,
line_txt
)
else
:
else
:
logger
.
error
(
f
"Can't find the location of image
{
img
[
'image_path'
]
}
in the markdown file #1"
)
logger
.
error
(
else
:
# 应当在两个block之间
f
"Can't find the location of image
{
img
[
'image_path'
]
}
in the markdown file #1"
)
else
:
# 应当在两个block之间
# 找到上方最近的block,如果上方没有就找大下方最近的block
# 找到上方最近的block,如果上方没有就找大下方最近的block
top_txt_block
=
find_top_nearest_text_bbox
(
pymu_raw_blocks
,
imgbox
)
top_txt_block
=
find_top_nearest_text_bbox
(
pymu_raw_blocks
,
imgbox
)
if
top_txt_block
:
if
top_txt_block
:
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
top_txt_block
[
'lines'
][
-
1
][
'spans'
]])
line_txt
=
''
.
join
(
[
s
[
'text'
]
for
s
in
top_txt_block
[
'lines'
][
-
1
][
'spans'
]]
)
page_md
=
__insert_after
(
page_md
,
img_content
,
line_txt
)
page_md
=
__insert_after
(
page_md
,
img_content
,
line_txt
)
else
:
else
:
bottom_txt_block
=
find_bottom_nearest_text_bbox
(
pymu_raw_blocks
,
imgbox
)
bottom_txt_block
=
find_bottom_nearest_text_bbox
(
pymu_raw_blocks
,
imgbox
)
if
bottom_txt_block
:
if
bottom_txt_block
:
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
bottom_txt_block
[
'lines'
][
0
][
'spans'
]])
line_txt
=
''
.
join
(
[
s
[
'text'
]
for
s
in
bottom_txt_block
[
'lines'
][
0
][
'spans'
]
]
)
page_md
=
__insert_before
(
page_md
,
img_content
,
line_txt
)
page_md
=
__insert_before
(
page_md
,
img_content
,
line_txt
)
else
:
else
:
logger
.
error
(
f
"Can't find the location of image
{
img
[
'image_path'
]
}
in the markdown file #2"
)
logger
.
error
(
f
"Can't find the location of image
{
img
[
'image_path'
]
}
in the markdown file #2"
)
content_lst
.
append
(
page_md
)
content_lst
.
append
(
page_md
)
"""拼装成全部页面的文本"""
"""拼装成全部页面的文本"""
content_text
=
"
\n\n
"
.
join
(
content_lst
)
content_text
=
'
\n\n
'
.
join
(
content_lst
)
return
content_text
return
content_text
def
__insert_after_para
(
text
,
type
,
element
,
content_list
):
def
__insert_after_para
(
text
,
type
,
element
,
content_list
):
"""
"""在content_list中找到text,将image_path作为一个新的node插入到text后面."""
在content_list中找到text,将image_path作为一个新的node插入到text后面
"""
for
i
,
c
in
enumerate
(
content_list
):
for
i
,
c
in
enumerate
(
content_list
):
content_type
=
c
.
get
(
"
type
"
)
content_type
=
c
.
get
(
'
type
'
)
if
content_type
in
UNI_FORMAT_TEXT_TYPE
and
text
in
c
.
get
(
"
text
"
,
''
):
if
content_type
in
UNI_FORMAT_TEXT_TYPE
and
text
in
c
.
get
(
'
text
'
,
''
):
if
type
==
"
image
"
:
if
type
==
'
image
'
:
content_node
=
{
content_node
=
{
"
type
"
:
"
image
"
,
'
type
'
:
'
image
'
,
"
img_path
"
:
element
.
get
(
"
image_path
"
),
'
img_path
'
:
element
.
get
(
'
image_path
'
),
"
img_alt
"
:
""
,
'
img_alt
'
:
''
,
"
img_title
"
:
""
,
'
img_title
'
:
''
,
"
img_caption
"
:
""
,
'
img_caption
'
:
''
,
}
}
elif
type
==
"
table
"
:
elif
type
==
'
table
'
:
content_node
=
{
content_node
=
{
"
type
"
:
"
table
"
,
'
type
'
:
'
table
'
,
"
img_path
"
:
element
.
get
(
"
image_path
"
),
'
img_path
'
:
element
.
get
(
'
image_path
'
),
"
table_latex
"
:
element
.
get
(
"
text
"
),
'
table_latex
'
:
element
.
get
(
'
text
'
),
"
table_title
"
:
""
,
'
table_title
'
:
''
,
"
table_caption
"
:
""
,
'
table_caption
'
:
''
,
"
table_quality
"
:
element
.
get
(
"
quality
"
),
'
table_quality
'
:
element
.
get
(
'
quality
'
),
}
}
content_list
.
insert
(
i
+
1
,
content_node
)
content_list
.
insert
(
i
+
1
,
content_node
)
break
break
else
:
else
:
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file, search target is
{
text
}
"
)
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file, search target is
{
text
}
"
)
def
__insert_before_para
(
text
,
type
,
element
,
content_list
):
def
__insert_before_para
(
text
,
type
,
element
,
content_list
):
"""
"""在content_list中找到text,将image_path作为一个新的node插入到text前面."""
在content_list中找到text,将image_path作为一个新的node插入到text前面
"""
for
i
,
c
in
enumerate
(
content_list
):
for
i
,
c
in
enumerate
(
content_list
):
content_type
=
c
.
get
(
"
type
"
)
content_type
=
c
.
get
(
'
type
'
)
if
content_type
in
UNI_FORMAT_TEXT_TYPE
and
text
in
c
.
get
(
"
text
"
,
''
):
if
content_type
in
UNI_FORMAT_TEXT_TYPE
and
text
in
c
.
get
(
'
text
'
,
''
):
if
type
==
"
image
"
:
if
type
==
'
image
'
:
content_node
=
{
content_node
=
{
"
type
"
:
"
image
"
,
'
type
'
:
'
image
'
,
"
img_path
"
:
element
.
get
(
"
image_path
"
),
'
img_path
'
:
element
.
get
(
'
image_path
'
),
"
img_alt
"
:
""
,
'
img_alt
'
:
''
,
"
img_title
"
:
""
,
'
img_title
'
:
''
,
"
img_caption
"
:
""
,
'
img_caption
'
:
''
,
}
}
elif
type
==
"
table
"
:
elif
type
==
'
table
'
:
content_node
=
{
content_node
=
{
"
type
"
:
"
table
"
,
'
type
'
:
'
table
'
,
"
img_path
"
:
element
.
get
(
"
image_path
"
),
'
img_path
'
:
element
.
get
(
'
image_path
'
),
"
table_latex
"
:
element
.
get
(
"
text
"
),
'
table_latex
'
:
element
.
get
(
'
text
'
),
"
table_title
"
:
""
,
'
table_title
'
:
''
,
"
table_caption
"
:
""
,
'
table_caption
'
:
''
,
"
table_quality
"
:
element
.
get
(
"
quality
"
),
'
table_quality
'
:
element
.
get
(
'
quality
'
),
}
}
content_list
.
insert
(
i
,
content_node
)
content_list
.
insert
(
i
,
content_node
)
break
break
else
:
else
:
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file, search target is
{
text
}
"
)
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file, search target is
{
text
}
"
)
def
mk_universal_format
(
pdf_info_list
:
list
,
img_buket_path
):
def
mk_universal_format
(
pdf_info_list
:
list
,
img_buket_path
):
"""
"""构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY."""
构造统一格式 https://aicarrier.feishu.cn/wiki/FqmMwcH69iIdCWkkyjvcDwNUnTY
"""
content_lst
=
[]
content_lst
=
[]
for
page_info
in
pdf_info_list
:
for
page_info
in
pdf_info_list
:
page_lst
=
[]
# 一个page内的段落列表
page_lst
=
[]
# 一个page内的段落列表
para_blocks
=
page_info
.
get
(
"
para_blocks
"
)
para_blocks
=
page_info
.
get
(
'
para_blocks
'
)
pymu_raw_blocks
=
page_info
.
get
(
"
preproc_blocks
"
)
pymu_raw_blocks
=
page_info
.
get
(
'
preproc_blocks
'
)
all_page_images
=
[]
all_page_images
=
[]
all_page_images
.
extend
(
page_info
.
get
(
"
images
"
,
[]))
all_page_images
.
extend
(
page_info
.
get
(
'
images
'
,
[]))
all_page_images
.
extend
(
page_info
.
get
(
"
image_backup
"
,
[])
)
all_page_images
.
extend
(
page_info
.
get
(
'
image_backup
'
,
[]))
# all_page_images.extend(page_info.get("tables",[]))
# all_page_images.extend(page_info.get("tables",[]))
# all_page_images.extend(page_info.get("table_backup",[]) )
# all_page_images.extend(page_info.get("table_backup",[]) )
all_page_tables
=
[]
all_page_tables
=
[]
all_page_tables
.
extend
(
page_info
.
get
(
"
tables
"
,
[]))
all_page_tables
.
extend
(
page_info
.
get
(
'
tables
'
,
[]))
if
not
para_blocks
or
not
pymu_raw_blocks
:
# 只有图片的拼接的场景
if
not
para_blocks
or
not
pymu_raw_blocks
:
# 只有图片的拼接的场景
for
img
in
all_page_images
:
for
img
in
all_page_images
:
content_node
=
{
content_node
=
{
"
type
"
:
"
image
"
,
'
type
'
:
'
image
'
,
"
img_path
"
:
join_path
(
img_buket_path
,
img
[
'image_path'
]),
'
img_path
'
:
join_path
(
img_buket_path
,
img
[
'image_path'
]),
"
img_alt
"
:
""
,
'
img_alt
'
:
''
,
"
img_title
"
:
""
,
'
img_title
'
:
''
,
"
img_caption
"
:
""
'
img_caption
'
:
''
,
}
}
page_lst
.
append
(
content_node
)
# TODO 图片顺序
page_lst
.
append
(
content_node
)
# TODO 图片顺序
for
table
in
all_page_tables
:
for
table
in
all_page_tables
:
content_node
=
{
content_node
=
{
"
type
"
:
"
table
"
,
'
type
'
:
'
table
'
,
"
img_path
"
:
join_path
(
img_buket_path
,
table
[
'image_path'
]),
'
img_path
'
:
join_path
(
img_buket_path
,
table
[
'image_path'
]),
"
table_latex
"
:
table
.
get
(
"
text
"
),
'
table_latex
'
:
table
.
get
(
'
text
'
),
"
table_title
"
:
""
,
'
table_title
'
:
''
,
"
table_caption
"
:
""
,
'
table_caption
'
:
''
,
"
table_quality
"
:
table
.
get
(
"
quality
"
),
'
table_quality
'
:
table
.
get
(
'
quality
'
),
}
}
page_lst
.
append
(
content_node
)
# TODO 图片顺序
page_lst
.
append
(
content_node
)
# TODO 图片顺序
else
:
else
:
for
block
in
para_blocks
:
for
block
in
para_blocks
:
item
=
block
[
"
paras
"
]
item
=
block
[
'
paras
'
]
for
_
,
p
in
item
.
items
():
for
_
,
p
in
item
.
items
():
font_type
=
p
[
'para_font_type'
]
# 对于文本来说,要么是普通文本,要么是个行间公式
font_type
=
p
[
'para_font_type'
]
# 对于文本来说,要么是普通文本,要么是个行间公式
if
font_type
==
TYPE_INTERLINE_EQUATION
:
if
font_type
==
TYPE_INTERLINE_EQUATION
:
content_node
=
{
content_node
=
{
'type'
:
'equation'
,
'latex'
:
p
[
'para_text'
]}
"type"
:
"equation"
,
"latex"
:
p
[
"para_text"
]
}
page_lst
.
append
(
content_node
)
page_lst
.
append
(
content_node
)
else
:
else
:
para_text
=
p
[
"
para_text
"
]
para_text
=
p
[
'
para_text
'
]
is_title
=
p
[
"
is_para_title
"
]
is_title
=
p
[
'
is_para_title
'
]
title_level
=
p
[
'para_title_level'
]
title_level
=
p
[
'para_title_level'
]
if
is_title
:
if
is_title
:
content_node
=
{
content_node
=
{
"
type
"
:
f
"
h
{
title_level
}
"
,
'
type
'
:
f
'
h
{
title_level
}
'
,
"
text
"
:
para_text
'
text
'
:
para_text
,
}
}
page_lst
.
append
(
content_node
)
page_lst
.
append
(
content_node
)
else
:
else
:
content_node
=
{
content_node
=
{
'type'
:
'text'
,
'text'
:
para_text
}
"type"
:
"text"
,
"text"
:
para_text
}
page_lst
.
append
(
content_node
)
page_lst
.
append
(
content_node
)
content_lst
.
extend
(
page_lst
)
content_lst
.
extend
(
page_lst
)
"""插入图片"""
"""插入图片"""
for
img
in
all_page_images
:
for
img
in
all_page_images
:
insert_img_or_table
(
"
image
"
,
img
,
pymu_raw_blocks
,
content_lst
)
insert_img_or_table
(
'
image
'
,
img
,
pymu_raw_blocks
,
content_lst
)
"""插入表格"""
"""插入表格"""
for
table
in
all_page_tables
:
for
table
in
all_page_tables
:
insert_img_or_table
(
"
table
"
,
table
,
pymu_raw_blocks
,
content_lst
)
insert_img_or_table
(
'
table
'
,
table
,
pymu_raw_blocks
,
content_lst
)
# end for
# end for
return
content_lst
return
content_lst
...
@@ -313,13 +342,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
...
@@ -313,13 +342,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
# 先看在哪个block内
# 先看在哪个block内
for
block
in
pymu_raw_blocks
:
for
block
in
pymu_raw_blocks
:
bbox
=
block
[
'bbox'
]
bbox
=
block
[
'bbox'
]
if
bbox
[
0
]
-
1
<=
element_bbox
[
0
]
<
bbox
[
2
]
+
1
and
bbox
[
1
]
-
1
<=
element_bbox
[
1
]
<
bbox
[
if
(
3
]
+
1
:
# 确定在这个大的block内,然后进入逐行比较距离
bbox
[
0
]
-
1
<=
element_bbox
[
0
]
<
bbox
[
2
]
+
1
for
l
in
block
[
'lines'
]:
and
bbox
[
1
]
-
1
<=
element_bbox
[
1
]
<
bbox
[
3
]
+
1
):
# 确定在这个大的block内,然后进入逐行比较距离
for
l
in
block
[
'lines'
]:
# noqa: E741
line_box
=
l
[
'bbox'
]
line_box
=
l
[
'bbox'
]
if
line_box
[
0
]
-
1
<=
element_bbox
[
0
]
<
line_box
[
2
]
+
1
and
line_box
[
1
]
-
1
<=
element_bbox
[
1
]
<
line_box
[
if
(
3
]
+
1
:
# 在line内的,插入line前面
line_box
[
0
]
-
1
<=
element_bbox
[
0
]
<
line_box
[
2
]
+
1
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
l
[
'spans'
]])
and
line_box
[
1
]
-
1
<=
element_bbox
[
1
]
<
line_box
[
3
]
+
1
):
# 在line内的,插入line前面
line_txt
=
''
.
join
([
s
[
'text'
]
for
s
in
l
[
'spans'
]])
__insert_before_para
(
line_txt
,
type
,
element
,
content_lst
)
__insert_before_para
(
line_txt
,
type
,
element
,
content_lst
)
break
break
break
break
...
@@ -327,14 +360,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
...
@@ -327,14 +360,17 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
# 找到图片x0,y0与line的x0,y0最近的line
# 找到图片x0,y0与line的x0,y0最近的line
min_distance
=
100000
min_distance
=
100000
min_line
=
None
min_line
=
None
for
l
in
block
[
'lines'
]:
for
l
in
block
[
'lines'
]:
# noqa: E741
line_box
=
l
[
'bbox'
]
line_box
=
l
[
'bbox'
]
distance
=
math
.
sqrt
((
line_box
[
0
]
-
element_bbox
[
0
])
**
2
+
(
line_box
[
1
]
-
element_bbox
[
1
])
**
2
)
distance
=
math
.
sqrt
(
(
line_box
[
0
]
-
element_bbox
[
0
])
**
2
+
(
line_box
[
1
]
-
element_bbox
[
1
])
**
2
)
if
distance
<
min_distance
:
if
distance
<
min_distance
:
min_distance
=
distance
min_distance
=
distance
min_line
=
l
min_line
=
l
if
min_line
:
if
min_line
:
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
min_line
[
'spans'
]])
line_txt
=
''
.
join
([
s
[
'text'
]
for
s
in
min_line
[
'spans'
]])
img_h
=
element_bbox
[
3
]
-
element_bbox
[
1
]
img_h
=
element_bbox
[
3
]
-
element_bbox
[
1
]
if
min_distance
<
img_h
:
# 文字在图片前面
if
min_distance
<
img_h
:
# 文字在图片前面
__insert_after_para
(
line_txt
,
type
,
element
,
content_lst
)
__insert_after_para
(
line_txt
,
type
,
element
,
content_lst
)
...
@@ -342,56 +378,61 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
...
@@ -342,56 +378,61 @@ def insert_img_or_table(type, element, pymu_raw_blocks, content_lst):
__insert_before_para
(
line_txt
,
type
,
element
,
content_lst
)
__insert_before_para
(
line_txt
,
type
,
element
,
content_lst
)
break
break
else
:
else
:
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file #1"
)
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file #1"
)
else
:
# 应当在两个block之间
else
:
# 应当在两个block之间
# 找到上方最近的block,如果上方没有就找大下方最近的block
# 找到上方最近的block,如果上方没有就找大下方最近的block
top_txt_block
=
find_top_nearest_text_bbox
(
pymu_raw_blocks
,
element_bbox
)
top_txt_block
=
find_top_nearest_text_bbox
(
pymu_raw_blocks
,
element_bbox
)
if
top_txt_block
:
if
top_txt_block
:
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
top_txt_block
[
'lines'
][
-
1
][
'spans'
]])
line_txt
=
''
.
join
([
s
[
'text'
]
for
s
in
top_txt_block
[
'lines'
][
-
1
][
'spans'
]])
__insert_after_para
(
line_txt
,
type
,
element
,
content_lst
)
__insert_after_para
(
line_txt
,
type
,
element
,
content_lst
)
else
:
else
:
bottom_txt_block
=
find_bottom_nearest_text_bbox
(
pymu_raw_blocks
,
element_bbox
)
bottom_txt_block
=
find_bottom_nearest_text_bbox
(
pymu_raw_blocks
,
element_bbox
)
if
bottom_txt_block
:
if
bottom_txt_block
:
line_txt
=
""
.
join
([
s
[
'text'
]
for
s
in
bottom_txt_block
[
'lines'
][
0
][
'spans'
]])
line_txt
=
''
.
join
(
[
s
[
'text'
]
for
s
in
bottom_txt_block
[
'lines'
][
0
][
'spans'
]]
)
__insert_before_para
(
line_txt
,
type
,
element
,
content_lst
)
__insert_before_para
(
line_txt
,
type
,
element
,
content_lst
)
else
:
# TODO ,图片可能独占一列,这种情况上下是没有图片的
else
:
# TODO ,图片可能独占一列,这种情况上下是没有图片的
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file #2"
)
logger
.
error
(
f
"Can't find the location of image
{
element
.
get
(
'image_path'
)
}
in the markdown file #2"
)
def
mk_mm_markdown
(
content_list
):
def
mk_mm_markdown
(
content_list
):
"""
"""基于同一格式的内容列表,构造markdown,含图片."""
基于同一格式的内容列表,构造markdown,含图片
"""
content_md
=
[]
content_md
=
[]
for
c
in
content_list
:
for
c
in
content_list
:
content_type
=
c
.
get
(
"
type
"
)
content_type
=
c
.
get
(
'
type
'
)
if
content_type
==
"
text
"
:
if
content_type
==
'
text
'
:
content_md
.
append
(
c
.
get
(
"
text
"
))
content_md
.
append
(
c
.
get
(
'
text
'
))
elif
content_type
==
"
equation
"
:
elif
content_type
==
'
equation
'
:
content
=
c
.
get
(
"
latex
"
)
content
=
c
.
get
(
'
latex
'
)
if
content
.
startswith
(
"
$$
"
)
and
content
.
endswith
(
"
$$
"
):
if
content
.
startswith
(
'
$$
'
)
and
content
.
endswith
(
'
$$
'
):
content_md
.
append
(
content
)
content_md
.
append
(
content
)
else
:
else
:
content_md
.
append
(
f
"
\n
$$
\n
{
c
.
get
(
'latex'
)
}
\n
$$
\n
"
)
content_md
.
append
(
f
"
\n
$$
\n
{
c
.
get
(
'latex'
)
}
\n
$$
\n
"
)
elif
content_type
in
UNI_FORMAT_TEXT_TYPE
:
elif
content_type
in
UNI_FORMAT_TEXT_TYPE
:
content_md
.
append
(
f
"
{
'#'
*
int
(
content_type
[
1
])
}
{
c
.
get
(
'text'
)
}
"
)
content_md
.
append
(
f
"
{
'#'
*
int
(
content_type
[
1
])
}
{
c
.
get
(
'text'
)
}
"
)
elif
content_type
==
"
image
"
:
elif
content_type
==
'
image
'
:
content_md
.
append
(
f
"
}
)"
)
content_md
.
append
(
f
"
}
)"
)
return
"
\n\n
"
.
join
(
content_md
)
return
'
\n\n
'
.
join
(
content_md
)
def
mk_nlp_markdown
(
content_list
):
def
mk_nlp_markdown
(
content_list
):
"""
"""基于同一格式的内容列表,构造markdown,不含图片."""
基于同一格式的内容列表,构造markdown,不含图片
"""
content_md
=
[]
content_md
=
[]
for
c
in
content_list
:
for
c
in
content_list
:
content_type
=
c
.
get
(
"
type
"
)
content_type
=
c
.
get
(
'
type
'
)
if
content_type
==
"
text
"
:
if
content_type
==
'
text
'
:
content_md
.
append
(
c
.
get
(
"
text
"
))
content_md
.
append
(
c
.
get
(
'
text
'
))
elif
content_type
==
"
equation
"
:
elif
content_type
==
'
equation
'
:
content_md
.
append
(
f
"$$
\n
{
c
.
get
(
'latex'
)
}
\n
$$"
)
content_md
.
append
(
f
"$$
\n
{
c
.
get
(
'latex'
)
}
\n
$$"
)
elif
content_type
==
"
table
"
:
elif
content_type
==
'
table
'
:
content_md
.
append
(
f
"$$$
\n
{
c
.
get
(
'table_latex'
)
}
\n
$$$"
)
content_md
.
append
(
f
"$$$
\n
{
c
.
get
(
'table_latex'
)
}
\n
$$$"
)
elif
content_type
in
UNI_FORMAT_TEXT_TYPE
:
elif
content_type
in
UNI_FORMAT_TEXT_TYPE
:
content_md
.
append
(
f
"
{
'#'
*
int
(
content_type
[
1
])
}
{
c
.
get
(
'text'
)
}
"
)
content_md
.
append
(
f
"
{
'#'
*
int
(
content_type
[
1
])
}
{
c
.
get
(
'text'
)
}
"
)
return
"
\n\n
"
.
join
(
content_md
)
return
'
\n\n
'
.
join
(
content_md
)
\ No newline at end of file
magic_pdf/dict2md/ocr_mkcontent.py
View file @
ea2f8ea0
...
@@ -2,21 +2,20 @@ import re
...
@@ -2,21 +2,20 @@ import re
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.config.make_content_config
import
DropMode
,
MakeMode
from
magic_pdf.config.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.libs.language
import
detect_lang
from
magic_pdf.libs.language
import
detect_lang
from
magic_pdf.libs.MakeContentConfig
import
DropMode
,
MakeMode
from
magic_pdf.libs.markdown_utils
import
ocr_escape_special_markdown_char
from
magic_pdf.libs.markdown_utils
import
ocr_escape_special_markdown_char
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.para.para_split_v3
import
ListLineTag
from
magic_pdf.para.para_split_v3
import
ListLineTag
def
__is_hyphen_at_line_end
(
line
):
def
__is_hyphen_at_line_end
(
line
):
"""
"""Check if a line ends with one or more letters followed by a hyphen.
Check if a line ends with one or more letters followed by a hyphen.
Args:
Args:
line (str): The line of text to check.
line (str): The line of text to check.
Returns:
Returns:
bool: True if the line ends with one or more letters followed by a hyphen, False otherwise.
bool: True if the line ends with one or more letters followed by a hyphen, False otherwise.
"""
"""
...
@@ -163,7 +162,7 @@ def merge_para_with_text(para_block):
...
@@ -163,7 +162,7 @@ def merge_para_with_text(para_block):
if
span_type
in
[
ContentType
.
Text
,
ContentType
.
InterlineEquation
]:
if
span_type
in
[
ContentType
.
Text
,
ContentType
.
InterlineEquation
]:
para_text
+=
content
# 中文/日语/韩文语境下,content间不需要空格分隔
para_text
+=
content
# 中文/日语/韩文语境下,content间不需要空格分隔
elif
span_type
==
ContentType
.
InlineEquation
:
elif
span_type
==
ContentType
.
InlineEquation
:
para_text
+=
f
"
{
content
}
"
para_text
+=
f
'
{
content
}
'
else
:
else
:
if
span_type
in
[
ContentType
.
Text
,
ContentType
.
InlineEquation
]:
if
span_type
in
[
ContentType
.
Text
,
ContentType
.
InlineEquation
]:
# 如果span是line的最后一个且末尾带有-连字符,那么末尾不应该加空格,同时应该把-删除
# 如果span是line的最后一个且末尾带有-连字符,那么末尾不应该加空格,同时应该把-删除
...
@@ -172,7 +171,7 @@ def merge_para_with_text(para_block):
...
@@ -172,7 +171,7 @@ def merge_para_with_text(para_block):
elif
len
(
content
)
==
1
and
content
not
in
[
'A'
,
'I'
,
'a'
,
'i'
]
and
not
content
.
isdigit
():
elif
len
(
content
)
==
1
and
content
not
in
[
'A'
,
'I'
,
'a'
,
'i'
]
and
not
content
.
isdigit
():
para_text
+=
content
para_text
+=
content
else
:
# 西方文本语境下 content间需要空格分隔
else
:
# 西方文本语境下 content间需要空格分隔
para_text
+=
f
"
{
content
}
"
para_text
+=
f
'
{
content
}
'
elif
span_type
==
ContentType
.
InterlineEquation
:
elif
span_type
==
ContentType
.
InterlineEquation
:
para_text
+=
content
para_text
+=
content
else
:
else
:
...
...
magic_pdf/filter/pdf_meta_scan.py
View file @
ea2f8ea0
"""
"""输入: s3路径,每行一个 输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置."""
输入: s3路径,每行一个
输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置
"""
import
sys
import
sys
import
click
from
collections
import
Counter
from
magic_pdf.libs.commons
import
read_file
,
mymax
,
get_top_percent_list
import
click
from
magic_pdf.libs.commons
import
fitz
from
loguru
import
logger
from
loguru
import
logger
from
collections
import
Counter
from
magic_pdf.libs.drop_reason
import
DropReason
from
magic_pdf.config.drop_reason
import
DropReason
from
magic_pdf.libs.commons
import
fitz
,
get_top_percent_list
,
mymax
,
read_file
from
magic_pdf.libs.language
import
detect_lang
from
magic_pdf.libs.language
import
detect_lang
from
magic_pdf.libs.pdf_check
import
detect_invalid_chars
from
magic_pdf.libs.pdf_check
import
detect_invalid_chars
...
@@ -19,8 +16,10 @@ junk_limit_min = 10
...
@@ -19,8 +16,10 @@ junk_limit_min = 10
def
calculate_max_image_area_per_page
(
result
:
list
,
page_width_pts
,
page_height_pts
):
def
calculate_max_image_area_per_page
(
result
:
list
,
page_width_pts
,
page_height_pts
):
max_image_area_per_page
=
[
mymax
([(
x1
-
x0
)
*
(
y1
-
y0
)
for
x0
,
y0
,
x1
,
y1
,
_
in
page_img_sz
])
for
page_img_sz
in
max_image_area_per_page
=
[
result
]
mymax
([(
x1
-
x0
)
*
(
y1
-
y0
)
for
x0
,
y0
,
x1
,
y1
,
_
in
page_img_sz
])
for
page_img_sz
in
result
]
page_area
=
int
(
page_width_pts
)
*
int
(
page_height_pts
)
page_area
=
int
(
page_width_pts
)
*
int
(
page_height_pts
)
max_image_area_per_page
=
[
area
/
page_area
for
area
in
max_image_area_per_page
]
max_image_area_per_page
=
[
area
/
page_area
for
area
in
max_image_area_per_page
]
max_image_area_per_page
=
[
area
for
area
in
max_image_area_per_page
if
area
>
0.6
]
max_image_area_per_page
=
[
area
for
area
in
max_image_area_per_page
if
area
>
0.6
]
...
@@ -32,8 +31,10 @@ def process_image(page, junk_img_bojids=[]):
...
@@ -32,8 +31,10 @@ def process_image(page, junk_img_bojids=[]):
items
=
page
.
get_images
()
items
=
page
.
get_images
()
dedup
=
set
()
dedup
=
set
()
for
img
in
items
:
for
img
in
items
:
# 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
# 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
img_bojid
=
img
[
0
]
# 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
img_bojid
=
img
[
0
]
# 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
if
img_bojid
in
junk_img_bojids
:
# 如果是垃圾图像,就跳过
if
img_bojid
in
junk_img_bojids
:
# 如果是垃圾图像,就跳过
continue
continue
recs
=
page
.
get_image_rects
(
img
,
transform
=
True
)
recs
=
page
.
get_image_rects
(
img
,
transform
=
True
)
...
@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]):
...
@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]):
x0
,
y0
,
x1
,
y1
=
map
(
int
,
rec
)
x0
,
y0
,
x1
,
y1
=
map
(
int
,
rec
)
width
=
x1
-
x0
width
=
x1
-
x0
height
=
y1
-
y0
height
=
y1
-
y0
if
(
x0
,
y0
,
x1
,
y1
,
img_bojid
)
in
dedup
:
# 这里面会出现一些重复的bbox,无需重复出现,需要去掉
if
(
x0
,
y0
,
x1
,
y1
,
img_bojid
,
)
in
dedup
:
# 这里面会出现一些重复的bbox,无需重复出现,需要去掉
continue
continue
if
not
all
([
width
,
height
]):
# 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
if
not
all
(
[
width
,
height
]
):
# 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
continue
continue
dedup
.
add
((
x0
,
y0
,
x1
,
y1
,
img_bojid
))
dedup
.
add
((
x0
,
y0
,
x1
,
y1
,
img_bojid
))
page_result
.
append
([
x0
,
y0
,
x1
,
y1
,
img_bojid
])
page_result
.
append
([
x0
,
y0
,
x1
,
y1
,
img_bojid
])
...
@@ -52,29 +61,33 @@ def process_image(page, junk_img_bojids=[]):
...
@@ -52,29 +61,33 @@ def process_image(page, junk_img_bojids=[]):
def
get_image_info
(
doc
:
fitz
.
Document
,
page_width_pts
,
page_height_pts
)
->
list
:
def
get_image_info
(
doc
:
fitz
.
Document
,
page_width_pts
,
page_height_pts
)
->
list
:
"""
"""
返回每个页面里的图片的四元组,每个页面多个图片。
返回每个页面里的图片的四元组,每个页面多个图片。
:param doc:
:param doc:
:return:
:return:
"""
"""
# 使用 Counter 计数 img_bojid 的出现次数
#
使用 Counter 计数 img_bojid 的出现次数
img_bojid_counter
=
Counter
(
img
[
0
]
for
page
in
doc
for
img
in
page
.
get_images
())
img_bojid_counter
=
Counter
(
img
[
0
]
for
page
in
doc
for
img
in
page
.
get_images
())
# 找出出现次数超过 len(doc) 半数的 img_bojid
#
找出出现次数超过 len(doc) 半数的 img_bojid
junk_limit
=
max
(
len
(
doc
)
*
0.5
,
junk_limit_min
)
# 对一些页数比较少的进行豁免
junk_limit
=
max
(
len
(
doc
)
*
0.5
,
junk_limit_min
)
# 对一些页数比较少的进行豁免
junk_img_bojids
=
[
img_bojid
for
img_bojid
,
count
in
img_bojid_counter
.
items
()
if
count
>=
junk_limit
]
junk_img_bojids
=
[
img_bojid
#todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
for
img_bojid
,
count
in
img_bojid_counter
.
items
()
#有两种扫描版,一种文字版,这里可能会有误判
if
count
>=
junk_limit
#扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
]
#扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
#文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
# todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
# 有两种扫描版,一种文字版,这里可能会有误判
# 扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
# 扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
# 文 字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
imgs_len_list
=
[
len
(
page
.
get_images
())
for
page
in
doc
]
imgs_len_list
=
[
len
(
page
.
get_images
())
for
page
in
doc
]
special_limit_pages
=
10
special_limit_pages
=
10
# 统一用前十页结果做判断
#
统一用前十页结果做判断
result
=
[]
result
=
[]
break_loop
=
False
break_loop
=
False
for
i
,
page
in
enumerate
(
doc
):
for
i
,
page
in
enumerate
(
doc
):
...
@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
...
@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
break
break
if
i
>=
special_limit_pages
:
if
i
>=
special_limit_pages
:
break
break
page_result
=
process_image
(
page
)
# 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
page_result
=
process_image
(
page
)
# 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
result
.
append
(
page_result
)
result
.
append
(
page_result
)
for
item
in
result
:
for
item
in
result
:
if
not
any
(
item
):
# 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
if
not
any
(
if
max
(
imgs_len_list
)
==
min
(
imgs_len_list
)
and
max
(
item
imgs_len_list
)
>=
junk_limit_min
:
# 如果是特殊文字版,就把junklist置空并break
):
# 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
if
(
max
(
imgs_len_list
)
==
min
(
imgs_len_list
)
and
max
(
imgs_len_list
)
>=
junk_limit_min
):
# 如果是特殊文字版,就把junklist置空并break
junk_img_bojids
=
[]
junk_img_bojids
=
[]
else
:
# 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist
else
:
# 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist
pass
pass
...
@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
...
@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
top_eighty_percent
=
get_top_percent_list
(
imgs_len_list
,
0.8
)
top_eighty_percent
=
get_top_percent_list
(
imgs_len_list
,
0.8
)
# 检查前80%的元素是否都相等
# 检查前80%的元素是否都相等
if
len
(
set
(
top_eighty_percent
))
==
1
and
max
(
imgs_len_list
)
>=
junk_limit_min
:
if
len
(
set
(
top_eighty_percent
))
==
1
and
max
(
imgs_len_list
)
>=
junk_limit_min
:
# # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist
# # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist
# if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:
# if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:
#前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
# 前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
max_image_area_per_page
=
calculate_max_image_area_per_page
(
result
,
page_width_pts
,
page_height_pts
)
max_image_area_per_page
=
calculate_max_image_area_per_page
(
if
len
(
max_image_area_per_page
)
<
0.8
*
special_limit_pages
:
# 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
result
,
page_width_pts
,
page_height_pts
)
if
(
len
(
max_image_area_per_page
)
<
0.8
*
special_limit_pages
):
# 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
junk_img_bojids
=
[]
junk_img_bojids
=
[]
else
:
# 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist
else
:
# 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist
pass
pass
else
:
# 每页图片数量不一致,需要清掉junklist全量跑前50页图片
else
:
# 每页图片数量不一致,需要清掉junklist全量跑前50页图片
junk_img_bojids
=
[]
junk_img_bojids
=
[]
#正式进入取前50页图片的信息流程
#
正式进入取前50页图片的信息流程
result
=
[]
result
=
[]
for
i
,
page
in
enumerate
(
doc
):
for
i
,
page
in
enumerate
(
doc
):
if
i
>=
scan_max_page
:
if
i
>=
scan_max_page
:
...
@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
...
@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
def
get_pdf_page_size_pts
(
doc
:
fitz
.
Document
):
def
get_pdf_page_size_pts
(
doc
:
fitz
.
Document
):
page_cnt
=
len
(
doc
)
page_cnt
=
len
(
doc
)
l
:
int
=
min
(
page_cnt
,
50
)
l
:
int
=
min
(
page_cnt
,
50
)
#把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
#
把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
page_width_list
=
[]
page_width_list
=
[]
page_height_list
=
[]
page_height_list
=
[]
for
i
in
range
(
l
):
for
i
in
range
(
l
):
...
@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
...
@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
# 拿所有text的blocks
# 拿所有text的blocks
# text_block = page.get_text("words")
# text_block = page.get_text("words")
# text_block_len = sum([len(t[4]) for t in text_block])
# text_block_len = sum([len(t[4]) for t in text_block])
#拿所有text的str
#
拿所有text的str
text_block
=
page
.
get_text
(
"
text
"
)
text_block
=
page
.
get_text
(
'
text
'
)
text_block_len
=
len
(
text_block
)
text_block_len
=
len
(
text_block
)
# logger.info(f"page {page.number} text_block_len: {text_block_len}")
# logger.info(f"page {page.number} text_block_len: {text_block_len}")
text_len_lst
.
append
(
text_block_len
)
text_len_lst
.
append
(
text_block_len
)
...
@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
...
@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
def
get_pdf_text_layout_per_page
(
doc
:
fitz
.
Document
):
def
get_pdf_text_layout_per_page
(
doc
:
fitz
.
Document
):
"""
"""根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
Args:
Args:
doc (fitz.Document): PDF文档对象。
doc (fitz.Document): PDF文档对象。
Returns:
Returns:
List[str]: 每一页的文本布局(横向、纵向、未知)。
List[str]: 每一页的文本布局(横向、纵向、未知)。
"""
"""
text_layout_list
=
[]
text_layout_list
=
[]
...
@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
...
@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# 创建每一页的纵向和横向的文本行数计数器
# 创建每一页的纵向和横向的文本行数计数器
vertical_count
=
0
vertical_count
=
0
horizontal_count
=
0
horizontal_count
=
0
text_dict
=
page
.
get_text
(
"
dict
"
)
text_dict
=
page
.
get_text
(
'
dict
'
)
if
"
blocks
"
in
text_dict
:
if
'
blocks
'
in
text_dict
:
for
block
in
text_dict
[
"
blocks
"
]:
for
block
in
text_dict
[
'
blocks
'
]:
if
'lines'
in
block
:
if
'lines'
in
block
:
for
line
in
block
[
"
lines
"
]:
for
line
in
block
[
'
lines
'
]:
# 获取line的bbox顶点坐标
# 获取line的bbox顶点坐标
x0
,
y0
,
x1
,
y1
=
line
[
'bbox'
]
x0
,
y0
,
x1
,
y1
=
line
[
'bbox'
]
# 计算bbox的宽高
# 计算bbox的宽高
...
@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
...
@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
if
len
(
font_sizes
)
>
0
:
if
len
(
font_sizes
)
>
0
:
average_font_size
=
sum
(
font_sizes
)
/
len
(
font_sizes
)
average_font_size
=
sum
(
font_sizes
)
/
len
(
font_sizes
)
else
:
else
:
average_font_size
=
10
# 有的line拿不到font_size,先定一个阈值100
average_font_size
=
(
if
area
<=
average_font_size
**
2
:
# 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
10
# 有的line拿不到font_size,先定一个阈值100
)
if
(
area
<=
average_font_size
**
2
):
# 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
continue
continue
else
:
else
:
if
'wmode'
in
line
:
# 通过wmode判断文本方向
if
'wmode'
in
line
:
# 通过wmode判断文本方向
...
@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
...
@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
# print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
# 判断每一页的文本布局
# 判断每一页的文本布局
if
vertical_count
==
0
and
horizontal_count
==
0
:
# 该页没有文本,无法判断
if
vertical_count
==
0
and
horizontal_count
==
0
:
# 该页没有文本,无法判断
text_layout_list
.
append
(
"
unknow
"
)
text_layout_list
.
append
(
'
unknow
'
)
continue
continue
else
:
else
:
if
vertical_count
>
horizontal_count
:
# 该页的文本纵向行数大于横向的
if
vertical_count
>
horizontal_count
:
# 该页的文本纵向行数大于横向的
text_layout_list
.
append
(
"
vertical
"
)
text_layout_list
.
append
(
'
vertical
'
)
else
:
# 该页的文本横向行数大于纵向的
else
:
# 该页的文本横向行数大于纵向的
text_layout_list
.
append
(
"
horizontal
"
)
text_layout_list
.
append
(
'
horizontal
'
)
# logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
# logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
return
text_layout_list
return
text_layout_list
'''
定义一个自定义异常用来抛出单页svg太多的pdf
'''
"""
定义一个自定义异常用来抛出单页svg太多的pdf
"""
class
PageSvgsTooManyError
(
Exception
):
class
PageSvgsTooManyError
(
Exception
):
def
__init__
(
self
,
message
=
"
Page SVGs are too many
"
):
def
__init__
(
self
,
message
=
'
Page SVGs are too many
'
):
self
.
message
=
message
self
.
message
=
message
super
().
__init__
(
self
.
message
)
super
().
__init__
(
self
.
message
)
...
@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document):
...
@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document):
if
page_id
>=
scan_max_page
:
if
page_id
>=
scan_max_page
:
break
break
# 拿所有text的str
# 拿所有text的str
text_block
=
page
.
get_text
(
"
text
"
)
text_block
=
page
.
get_text
(
'
text
'
)
page_language
=
detect_lang
(
text_block
)
page_language
=
detect_lang
(
text_block
)
language_lst
.
append
(
page_language
)
language_lst
.
append
(
page_language
)
...
@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document):
...
@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document):
def
check_invalid_chars
(
pdf_bytes
):
def
check_invalid_chars
(
pdf_bytes
):
"""
"""乱码检测."""
乱码检测
"""
return
detect_invalid_chars
(
pdf_bytes
)
return
detect_invalid_chars
(
pdf_bytes
)
...
@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes):
...
@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes):
:param pdf_bytes: pdf文件的二进制数据
:param pdf_bytes: pdf文件的二进制数据
几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取
几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取
"""
"""
doc
=
fitz
.
open
(
"
pdf
"
,
pdf_bytes
)
doc
=
fitz
.
open
(
'
pdf
'
,
pdf_bytes
)
is_needs_password
=
doc
.
needs_pass
is_needs_password
=
doc
.
needs_pass
is_encrypted
=
doc
.
is_encrypted
is_encrypted
=
doc
.
is_encrypted
total_page
=
len
(
doc
)
total_page
=
len
(
doc
)
if
total_page
==
0
:
if
total_page
==
0
:
logger
.
warning
(
f
"
drop this pdf, drop_reason:
{
DropReason
.
EMPTY_PDF
}
"
)
logger
.
warning
(
f
'
drop this pdf, drop_reason:
{
DropReason
.
EMPTY_PDF
}
'
)
result
=
{
"
_need_drop
"
:
True
,
"
_drop_reason
"
:
DropReason
.
EMPTY_PDF
}
result
=
{
'
_need_drop
'
:
True
,
'
_drop_reason
'
:
DropReason
.
EMPTY_PDF
}
return
result
return
result
else
:
else
:
page_width_pts
,
page_height_pts
=
get_pdf_page_size_pts
(
doc
)
page_width_pts
,
page_height_pts
=
get_pdf_page_size_pts
(
doc
)
...
@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
...
@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
imgs_per_page
=
get_imgs_per_page
(
doc
)
imgs_per_page
=
get_imgs_per_page
(
doc
)
# logger.info(f"imgs_per_page: {imgs_per_page}")
# logger.info(f"imgs_per_page: {imgs_per_page}")
image_info_per_page
,
junk_img_bojids
=
get_image_info
(
doc
,
page_width_pts
,
page_height_pts
)
image_info_per_page
,
junk_img_bojids
=
get_image_info
(
doc
,
page_width_pts
,
page_height_pts
)
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
text_len_per_page
=
get_pdf_textlen_per_page
(
doc
)
text_len_per_page
=
get_pdf_textlen_per_page
(
doc
)
# logger.info(f"text_len_per_page: {text_len_per_page}")
# logger.info(f"text_len_per_page: {text_len_per_page}")
...
@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes):
...
@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes):
# 最后输出一条json
# 最后输出一条json
res
=
{
res
=
{
"
is_needs_password
"
:
is_needs_password
,
'
is_needs_password
'
:
is_needs_password
,
"
is_encrypted
"
:
is_encrypted
,
'
is_encrypted
'
:
is_encrypted
,
"
total_page
"
:
total_page
,
'
total_page
'
:
total_page
,
"
page_width_pts
"
:
int
(
page_width_pts
),
'
page_width_pts
'
:
int
(
page_width_pts
),
"
page_height_pts
"
:
int
(
page_height_pts
),
'
page_height_pts
'
:
int
(
page_height_pts
),
"
image_info_per_page
"
:
image_info_per_page
,
'
image_info_per_page
'
:
image_info_per_page
,
"
text_len_per_page
"
:
text_len_per_page
,
'
text_len_per_page
'
:
text_len_per_page
,
"
text_layout_per_page
"
:
text_layout_per_page
,
'
text_layout_per_page
'
:
text_layout_per_page
,
"
text_language
"
:
text_language
,
'
text_language
'
:
text_language
,
# "svgs_per_page": svgs_per_page,
# "svgs_per_page": svgs_per_page,
"
imgs_per_page
"
:
imgs_per_page
,
# 增加每页img数量list
'
imgs_per_page
'
:
imgs_per_page
,
# 增加每页img数量list
"
junk_img_bojids
"
:
junk_img_bojids
,
# 增加垃圾图片的bojid list
'
junk_img_bojids
'
:
junk_img_bojids
,
# 增加垃圾图片的bojid list
"
invalid_chars
"
:
invalid_chars
,
'
invalid_chars
'
:
invalid_chars
,
"
metadata
"
:
doc
.
metadata
'
metadata
'
:
doc
.
metadata
,
}
}
# logger.info(json.dumps(res, ensure_ascii=False))
# logger.info(json.dumps(res, ensure_ascii=False))
return
res
return
res
...
@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes):
...
@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes):
@
click
.
option
(
'--s3-pdf-path'
,
help
=
's3上pdf文件的路径'
)
@
click
.
option
(
'--s3-pdf-path'
,
help
=
's3上pdf文件的路径'
)
@
click
.
option
(
'--s3-profile'
,
help
=
's3上的profile'
)
@
click
.
option
(
'--s3-profile'
,
help
=
's3上的profile'
)
def
main
(
s3_pdf_path
:
str
,
s3_profile
:
str
):
def
main
(
s3_pdf_path
:
str
,
s3_profile
:
str
):
"""
""""""
"""
try
:
try
:
file_content
=
read_file
(
s3_pdf_path
,
s3_profile
)
file_content
=
read_file
(
s3_pdf_path
,
s3_profile
)
pdf_meta_scan
(
file_content
)
pdf_meta_scan
(
file_content
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"
ERROR:
{
s3_pdf_path
}
,
{
e
}
"
,
file
=
sys
.
stderr
)
print
(
f
'
ERROR:
{
s3_pdf_path
}
,
{
e
}
'
,
file
=
sys
.
stderr
)
logger
.
exception
(
e
)
logger
.
exception
(
e
)
...
@@ -381,7 +403,7 @@ if __name__ == '__main__':
...
@@ -381,7 +403,7 @@ if __name__ == '__main__':
# "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
# "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
# "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")
# noqa: E501
# file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
# file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
# doc = fitz.open("pdf", file_content)
# doc = fitz.open("pdf", file_content)
# text_layout_lst = get_pdf_text_layout_per_page(doc)
# text_layout_lst = get_pdf_text_layout_per_page(doc)
...
...
magic_pdf/integrations/rag/utils.py
View file @
ea2f8ea0
...
@@ -5,14 +5,13 @@ from pathlib import Path
...
@@ -5,14 +5,13 @@ from pathlib import Path
from
loguru
import
logger
from
loguru
import
logger
import
magic_pdf.model
as
model_config
import
magic_pdf.model
as
model_config
from
magic_pdf.config.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.data.data_reader_writer
import
FileBasedDataReader
from
magic_pdf.dict2md.ocr_mkcontent
import
merge_para_with_text
from
magic_pdf.dict2md.ocr_mkcontent
import
merge_para_with_text
from
magic_pdf.integrations.rag.type
import
(
CategoryType
,
ContentObject
,
from
magic_pdf.integrations.rag.type
import
(
CategoryType
,
ContentObject
,
ElementRelation
,
ElementRelType
,
ElementRelation
,
ElementRelType
,
LayoutElements
,
LayoutElements
,
LayoutElementsExtra
,
PageInfo
)
LayoutElementsExtra
,
PageInfo
)
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.rw.DiskReaderWriter
import
DiskReaderWriter
from
magic_pdf.tools.common
import
do_parse
,
prepare_env
from
magic_pdf.tools.common
import
do_parse
,
prepare_env
...
@@ -224,8 +223,8 @@ def inference(path, output_dir, method):
...
@@ -224,8 +223,8 @@ def inference(path, output_dir, method):
str
(
Path
(
path
).
stem
),
method
)
str
(
Path
(
path
).
stem
),
method
)
def
read_fn
(
path
):
def
read_fn
(
path
):
disk_rw
=
DiskReaderWrit
er
(
os
.
path
.
dirname
(
path
))
disk_rw
=
FileBasedDataRead
er
(
os
.
path
.
dirname
(
path
))
return
disk_rw
.
read
(
os
.
path
.
basename
(
path
)
,
AbsReaderWriter
.
MODE_BIN
)
return
disk_rw
.
read
(
os
.
path
.
basename
(
path
))
def
parse_doc
(
doc_path
:
str
):
def
parse_doc
(
doc_path
:
str
):
try
:
try
:
...
...
magic_pdf/libs/MakeContentConfig.py
deleted
100644 → 0
View file @
e4810cee
class
MakeMode
:
MM_MD
=
"mm_markdown"
NLP_MD
=
"nlp_markdown"
STANDARD_FORMAT
=
"standard_format"
class
DropMode
:
WHOLE_PDF
=
"whole_pdf"
SINGLE_PAGE
=
"single_page"
NONE
=
"none"
NONE_WITH_REASON
=
"none_with_reason"
magic_pdf/libs/config_reader.py
View file @
ea2f8ea0
...
@@ -5,7 +5,7 @@ import os
...
@@ -5,7 +5,7 @@ import os
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.
libs.C
onstants
import
MODEL_NAME
from
magic_pdf.
config.c
onstants
import
MODEL_NAME
from
magic_pdf.libs.commons
import
parse_bucket_key
from
magic_pdf.libs.commons
import
parse_bucket_key
# 定义配置文件名常量
# 定义配置文件名常量
...
@@ -99,7 +99,7 @@ def get_table_recog_config():
...
@@ -99,7 +99,7 @@ def get_table_recog_config():
def
get_layout_config
():
def
get_layout_config
():
config
=
read_config
()
config
=
read_config
()
layout_config
=
config
.
get
(
"
layout-config
"
)
layout_config
=
config
.
get
(
'
layout-config
'
)
if
layout_config
is
None
:
if
layout_config
is
None
:
logger
.
warning
(
f
"'layout-config' not found in
{
CONFIG_FILE_NAME
}
, use '
{
MODEL_NAME
.
LAYOUTLMv3
}
' as default"
)
logger
.
warning
(
f
"'layout-config' not found in
{
CONFIG_FILE_NAME
}
, use '
{
MODEL_NAME
.
LAYOUTLMv3
}
' as default"
)
return
json
.
loads
(
f
'{{"model": "
{
MODEL_NAME
.
LAYOUTLMv3
}
"}}'
)
return
json
.
loads
(
f
'{{"model": "
{
MODEL_NAME
.
LAYOUTLMv3
}
"}}'
)
...
@@ -109,7 +109,7 @@ def get_layout_config():
...
@@ -109,7 +109,7 @@ def get_layout_config():
def
get_formula_config
():
def
get_formula_config
():
config
=
read_config
()
config
=
read_config
()
formula_config
=
config
.
get
(
"
formula-config
"
)
formula_config
=
config
.
get
(
'
formula-config
'
)
if
formula_config
is
None
:
if
formula_config
is
None
:
logger
.
warning
(
f
"'formula-config' not found in
{
CONFIG_FILE_NAME
}
, use 'True' as default"
)
logger
.
warning
(
f
"'formula-config' not found in
{
CONFIG_FILE_NAME
}
, use 'True' as default"
)
return
json
.
loads
(
f
'{{"mfd_model": "
{
MODEL_NAME
.
YOLO_V8_MFD
}
","mfr_model": "
{
MODEL_NAME
.
UniMerNet_v2_Small
}
","enable": true}}'
)
return
json
.
loads
(
f
'{{"mfd_model": "
{
MODEL_NAME
.
YOLO_V8_MFD
}
","mfr_model": "
{
MODEL_NAME
.
UniMerNet_v2_Small
}
","enable": true}}'
)
...
@@ -117,5 +117,5 @@ def get_formula_config():
...
@@ -117,5 +117,5 @@ def get_formula_config():
return
formula_config
return
formula_config
if
__name__
==
"
__main__
"
:
if
__name__
==
'
__main__
'
:
ak
,
sk
,
endpoint
=
get_s3_config
(
"
llm-raw
"
)
ak
,
sk
,
endpoint
=
get_s3_config
(
'
llm-raw
'
)
magic_pdf/libs/draw_bbox.py
View file @
ea2f8ea0
from
magic_pdf.config.constants
import
CROSS_PAGE
from
magic_pdf.config.ocr_content_type
import
(
BlockType
,
CategoryId
,
ContentType
)
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.libs.commons
import
fitz
# PyMuPDF
from
magic_pdf.libs.commons
import
fitz
# PyMuPDF
from
magic_pdf.libs.Constants
import
CROSS_PAGE
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
CategoryId
,
ContentType
from
magic_pdf.model.magic_model
import
MagicModel
from
magic_pdf.model.magic_model
import
MagicModel
...
...
magic_pdf/libs/drop_reason.py
deleted
100644 → 0
View file @
e4810cee
class
DropReason
:
TEXT_BLCOK_HOR_OVERLAP
=
"text_block_horizontal_overlap"
# 文字块有水平互相覆盖,导致无法准确定位文字顺序
USEFUL_BLOCK_HOR_OVERLAP
=
"useful_block_horizontal_overlap"
# 需保留的block水平覆盖
COMPLICATED_LAYOUT
=
"complicated_layout"
# 复杂的布局,暂时不支持
TOO_MANY_LAYOUT_COLUMNS
=
"too_many_layout_columns"
# 目前不支持分栏超过2列的
COLOR_BACKGROUND_TEXT_BOX
=
"color_background_text_box"
# 含有带色块的PDF,色块会改变阅读顺序,目前不支持带底色文字块的PDF。
HIGH_COMPUTATIONAL_lOAD_BY_IMGS
=
"high_computational_load_by_imgs"
# 含特殊图片,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_SVGS
=
"high_computational_load_by_svgs"
# 特殊的SVG图,计算量太大,从而丢弃
HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
=
"high_computational_load_by_total_pages"
# 计算量超过负荷,当前方法下计算量消耗过大
MISS_DOC_LAYOUT_RESULT
=
"missing doc_layout_result"
# 版面分析失败
Exception
=
"_exception"
# 解析中发生异常
ENCRYPTED
=
"encrypted"
# PDF是加密的
EMPTY_PDF
=
"total_page=0"
# PDF页面总数为0
NOT_IS_TEXT_PDF
=
"not_is_text_pdf"
# 不是文字版PDF,无法直接解析
DENSE_SINGLE_LINE_BLOCK
=
"dense_single_line_block"
# 无法清晰的分段
TITLE_DETECTION_FAILED
=
"title_detection_failed"
# 探测标题失败
TITLE_LEVEL_FAILED
=
"title_level_failed"
# 分析标题级别失败(例如一级、二级、三级标题)
PARA_SPLIT_FAILED
=
"para_split_failed"
# 识别段落失败
PARA_MERGE_FAILED
=
"para_merge_failed"
# 段落合并失败
NOT_ALLOW_LANGUAGE
=
"not_allow_language"
# 不支持的语种
SPECIAL_PDF
=
"special_pdf"
PSEUDO_SINGLE_COLUMN
=
"pseudo_single_column"
# 无法精确判断文字分栏
CAN_NOT_DETECT_PAGE_LAYOUT
=
"can_not_detect_page_layout"
# 无法分析页面的版面
NEGATIVE_BBOX_AREA
=
"negative_bbox_area"
# 缩放导致 bbox 面积为负
OVERLAP_BLOCKS_CAN_NOT_SEPARATION
=
"overlap_blocks_can_t_separation"
# 无法分离重叠的block
\ No newline at end of file
magic_pdf/libs/drop_tag.py
deleted
100644 → 0
View file @
e4810cee
COLOR_BG_HEADER_TXT_BLOCK
=
"color_background_header_txt_block"
PAGE_NO
=
"page-no"
# 页码
CONTENT_IN_FOOT_OR_HEADER
=
'in-foot-header-area'
# 页眉页脚内的文本
VERTICAL_TEXT
=
'vertical-text'
# 垂直文本
ROTATE_TEXT
=
'rotate-text'
# 旋转文本
EMPTY_SIDE_BLOCK
=
'empty-side-block'
# 边缘上的空白没有任何内容的block
ON_IMAGE_TEXT
=
'on-image-text'
# 文本在图片上
ON_TABLE_TEXT
=
'on-table-text'
# 文本在表格上
class
DropTag
:
PAGE_NUMBER
=
"page_no"
HEADER
=
"header"
FOOTER
=
"footer"
FOOTNOTE
=
"footnote"
NOT_IN_LAYOUT
=
"not_in_layout"
SPAN_OVERLAP
=
"span_overlap"
BLOCK_OVERLAP
=
"block_overlap"
magic_pdf/libs/pdf_image_tools.py
View file @
ea2f8ea0
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.libs.commons
import
fitz
from
magic_pdf.libs.commons
import
fitz
,
join_path
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.libs.hash_utils
import
compute_sha256
from
magic_pdf.libs.hash_utils
import
compute_sha256
def
cut_image
(
bbox
:
tuple
,
page_num
:
int
,
page
:
fitz
.
Page
,
return_path
,
imageWriter
:
AbsReaderWriter
):
def
cut_image
(
bbox
:
tuple
,
page_num
:
int
,
page
:
fitz
.
Page
,
return_path
,
imageWriter
:
DataWriter
):
"""
"""从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
图片存放在save_path下,文件名是:
save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
"""
# 拼接文件名
# 拼接文件名
filename
=
f
"
{
page_num
}
_
{
int
(
bbox
[
0
])
}
_
{
int
(
bbox
[
1
])
}
_
{
int
(
bbox
[
2
])
}
_
{
int
(
bbox
[
3
])
}
"
filename
=
f
'
{
page_num
}
_
{
int
(
bbox
[
0
])
}
_
{
int
(
bbox
[
1
])
}
_
{
int
(
bbox
[
2
])
}
_
{
int
(
bbox
[
3
])
}
'
# 老版本返回不带bucket的路径
# 老版本返回不带bucket的路径
img_path
=
join_path
(
return_path
,
filename
)
if
return_path
is
not
None
else
None
img_path
=
join_path
(
return_path
,
filename
)
if
return_path
is
not
None
else
None
# 新版本生成平铺路径
# 新版本生成平铺路径
img_hash256_path
=
f
"
{
compute_sha256
(
img_path
)
}
.jpg
"
img_hash256_path
=
f
'
{
compute_sha256
(
img_path
)
}
.jpg
'
# 将坐标转换为fitz.Rect对象
# 将坐标转换为fitz.Rect对象
rect
=
fitz
.
Rect
(
*
bbox
)
rect
=
fitz
.
Rect
(
*
bbox
)
...
@@ -28,6 +26,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
...
@@ -28,6 +26,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
byte_data
=
pix
.
tobytes
(
output
=
'jpeg'
,
jpg_quality
=
95
)
byte_data
=
pix
.
tobytes
(
output
=
'jpeg'
,
jpg_quality
=
95
)
imageWriter
.
write
(
byte_data
,
img_hash256_path
,
AbsReaderWriter
.
MODE_BIN
)
imageWriter
.
write
(
img_hash256_path
,
byte_data
)
return
img_hash256_path
return
img_hash256_path
magic_pdf/model/magic_model.py
View file @
ea2f8ea0
import
enum
import
enum
import
json
import
json
from
magic_pdf.config.model_block_type
import
ModelBlockTypeEnum
from
magic_pdf.config.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.data.data_reader_writer
import
(
FileBasedDataReader
,
FileBasedDataWriter
)
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.boxbase
import
(
_is_in
,
_is_part_overlap
,
bbox_distance
,
from
magic_pdf.libs.boxbase
import
(
_is_in
,
_is_part_overlap
,
bbox_distance
,
bbox_relative_pos
,
box_area
,
calculate_iou
,
bbox_relative_pos
,
box_area
,
calculate_iou
,
...
@@ -9,11 +13,7 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
...
@@ -9,11 +13,7 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from
magic_pdf.libs.commons
import
fitz
,
join_path
from
magic_pdf.libs.commons
import
fitz
,
join_path
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.libs.local_math
import
float_gt
from
magic_pdf.libs.local_math
import
float_gt
from
magic_pdf.libs.ModelBlockTypeEnum
import
ModelBlockTypeEnum
from
magic_pdf.libs.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.rw.DiskReaderWriter
import
DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO
=
0.6
CAPATION_OVERLAP_AREA_RATIO
=
0.6
MERGE_BOX_OVERLAP_AREA_RATIO
=
1.1
MERGE_BOX_OVERLAP_AREA_RATIO
=
1.1
...
@@ -1050,27 +1050,27 @@ class MagicModel:
...
@@ -1050,27 +1050,27 @@ class MagicModel:
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
drw
=
DiskReaderWrit
er
(
r
'D:/project/20231108code-clean'
)
drw
=
FileBasedDataRead
er
(
r
'D:/project/20231108code-clean'
)
if
0
:
if
0
:
pdf_file_path
=
r
'linshixuqiu\19983-00.pdf'
pdf_file_path
=
r
'linshixuqiu\19983-00.pdf'
model_file_path
=
r
'linshixuqiu\19983-00_new.json'
model_file_path
=
r
'linshixuqiu\19983-00_new.json'
pdf_bytes
=
drw
.
read
(
pdf_file_path
,
AbsReaderWriter
.
MODE_BIN
)
pdf_bytes
=
drw
.
read
(
pdf_file_path
)
model_json_txt
=
drw
.
read
(
model_file_path
,
AbsReaderWriter
.
MODE_TXT
)
model_json_txt
=
drw
.
read
(
model_file_path
).
decode
(
)
model_list
=
json
.
loads
(
model_json_txt
)
model_list
=
json
.
loads
(
model_json_txt
)
write_path
=
r
'D:\project\20231108code-clean\linshixuqiu\19983-00'
write_path
=
r
'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path
=
'imgs'
img_bucket_path
=
'imgs'
img_writer
=
DiskReader
Writer
(
join_path
(
write_path
,
img_bucket_path
))
img_writer
=
FileBasedData
Writer
(
join_path
(
write_path
,
img_bucket_path
))
pdf_docs
=
fitz
.
open
(
'pdf'
,
pdf_bytes
)
pdf_docs
=
fitz
.
open
(
'pdf'
,
pdf_bytes
)
magic_model
=
MagicModel
(
model_list
,
pdf_docs
)
magic_model
=
MagicModel
(
model_list
,
pdf_docs
)
if
1
:
if
1
:
from
magic_pdf.data.dataset
import
PymuDocDataset
model_list
=
json
.
loads
(
model_list
=
json
.
loads
(
drw
.
read
(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.json'
)
drw
.
read
(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.json'
)
)
)
pdf_bytes
=
drw
.
read
(
pdf_bytes
=
drw
.
read
(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf'
)
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf'
,
AbsReaderWriter
.
MODE_BIN
)
magic_model
=
MagicModel
(
model_list
,
PymuDocDataset
(
pdf_bytes
))
pdf_docs
=
fitz
.
open
(
'pdf'
,
pdf_bytes
)
magic_model
=
MagicModel
(
model_list
,
pdf_docs
)
for
i
in
range
(
7
):
for
i
in
range
(
7
):
print
(
magic_model
.
get_imgs
(
i
))
print
(
magic_model
.
get_imgs
(
i
))
magic_pdf/model/pdf_extract_kit.py
View file @
ea2f8ea0
import
numpy
as
np
# flake8: noqa
import
torch
from
loguru
import
logger
import
os
import
os
import
time
import
time
import
cv2
import
cv2
import
numpy
as
np
import
torch
import
yaml
import
yaml
from
loguru
import
logger
from
PIL
import
Image
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
...
@@ -13,20 +15,21 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
...
@@ -13,20 +15,21 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try
:
try
:
import
torchtext
import
torchtext
if
torchtext
.
__version__
>=
"
0.18.0
"
:
if
torchtext
.
__version__
>=
'
0.18.0
'
:
torchtext
.
disable_torchtext_deprecation_warning
()
torchtext
.
disable_torchtext_deprecation_warning
()
except
ImportError
:
except
ImportError
:
pass
pass
from
magic_pdf.
libs.C
onstants
import
*
from
magic_pdf.
config.c
onstants
import
*
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
from
magic_pdf.model.sub_modules.model_utils
import
get_res_list_from_layout_res
,
crop_img
,
clean_vram
from
magic_pdf.model.sub_modules.model_utils
import
(
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
get_adjusted_mfdetrec_res
,
get_ocr_result_list
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
class
CustomPEKModel
:
class
CustomPEKModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
"""
"""
======== model init ========
======== model init ========
...
@@ -41,42 +44,54 @@ class CustomPEKModel:
...
@@ -41,42 +44,54 @@ class CustomPEKModel:
model_config_dir
=
os
.
path
.
join
(
root_dir
,
'resources'
,
'model_config'
)
model_config_dir
=
os
.
path
.
join
(
root_dir
,
'resources'
,
'model_config'
)
# 构建 model_configs.yaml 文件的完整路径
# 构建 model_configs.yaml 文件的完整路径
config_path
=
os
.
path
.
join
(
model_config_dir
,
'model_configs.yaml'
)
config_path
=
os
.
path
.
join
(
model_config_dir
,
'model_configs.yaml'
)
with
open
(
config_path
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
config_path
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
self
.
configs
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
self
.
configs
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
# 初始化解析配置
# 初始化解析配置
# layout config
# layout config
self
.
layout_config
=
kwargs
.
get
(
"layout_config"
)
self
.
layout_config
=
kwargs
.
get
(
'layout_config'
)
self
.
layout_model_name
=
self
.
layout_config
.
get
(
"model"
,
MODEL_NAME
.
DocLayout_YOLO
)
self
.
layout_model_name
=
self
.
layout_config
.
get
(
'model'
,
MODEL_NAME
.
DocLayout_YOLO
)
# formula config
# formula config
self
.
formula_config
=
kwargs
.
get
(
"formula_config"
)
self
.
formula_config
=
kwargs
.
get
(
'formula_config'
)
self
.
mfd_model_name
=
self
.
formula_config
.
get
(
"mfd_model"
,
MODEL_NAME
.
YOLO_V8_MFD
)
self
.
mfd_model_name
=
self
.
formula_config
.
get
(
self
.
mfr_model_name
=
self
.
formula_config
.
get
(
"mfr_model"
,
MODEL_NAME
.
UniMerNet_v2_Small
)
'mfd_model'
,
MODEL_NAME
.
YOLO_V8_MFD
self
.
apply_formula
=
self
.
formula_config
.
get
(
"enable"
,
True
)
)
self
.
mfr_model_name
=
self
.
formula_config
.
get
(
'mfr_model'
,
MODEL_NAME
.
UniMerNet_v2_Small
)
self
.
apply_formula
=
self
.
formula_config
.
get
(
'enable'
,
True
)
# table config
# table config
self
.
table_config
=
kwargs
.
get
(
"
table_config
"
)
self
.
table_config
=
kwargs
.
get
(
'
table_config
'
)
self
.
apply_table
=
self
.
table_config
.
get
(
"
enable
"
,
False
)
self
.
apply_table
=
self
.
table_config
.
get
(
'
enable
'
,
False
)
self
.
table_max_time
=
self
.
table_config
.
get
(
"
max_time
"
,
TABLE_MAX_TIME_VALUE
)
self
.
table_max_time
=
self
.
table_config
.
get
(
'
max_time
'
,
TABLE_MAX_TIME_VALUE
)
self
.
table_model_name
=
self
.
table_config
.
get
(
"
model
"
,
MODEL_NAME
.
RAPID_TABLE
)
self
.
table_model_name
=
self
.
table_config
.
get
(
'
model
'
,
MODEL_NAME
.
RAPID_TABLE
)
# ocr config
# ocr config
self
.
apply_ocr
=
ocr
self
.
apply_ocr
=
ocr
self
.
lang
=
kwargs
.
get
(
"
lang
"
,
None
)
self
.
lang
=
kwargs
.
get
(
'
lang
'
,
None
)
logger
.
info
(
logger
.
info
(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
"apply_table: {}, table_model: {}, lang: {}"
.
format
(
'apply_table: {}, table_model: {}, lang: {}'
.
format
(
self
.
layout_model_name
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
table_model_name
,
self
.
layout_model_name
,
self
.
lang
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
table_model_name
,
self
.
lang
,
)
)
)
)
# 初始化解析方案
# 初始化解析方案
self
.
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
self
.
device
=
kwargs
.
get
(
'device'
,
'cpu'
)
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
logger
.
info
(
'using device: {}'
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
models_dir
=
kwargs
.
get
(
logger
.
info
(
"using models_dir: {}"
.
format
(
models_dir
))
'models_dir'
,
os
.
path
.
join
(
root_dir
,
'resources'
,
'models'
)
)
logger
.
info
(
'using models_dir: {}'
.
format
(
models_dir
))
atom_model_manager
=
AtomModelSingleton
()
atom_model_manager
=
AtomModelSingleton
()
...
@@ -85,18 +100,24 @@ class CustomPEKModel:
...
@@ -85,18 +100,24 @@ class CustomPEKModel:
# 初始化公式检测模型
# 初始化公式检测模型
self
.
mfd_model
=
atom_model_manager
.
get_atom_model
(
self
.
mfd_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFD
,
atom_model_name
=
AtomicModel
.
MFD
,
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfd_model_name
])),
mfd_weights
=
str
(
device
=
self
.
device
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
mfd_model_name
]
)
),
device
=
self
.
device
,
)
)
# 初始化公式解析模型
# 初始化公式解析模型
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfr_model_name
]))
mfr_weight_dir
=
str
(
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
"UniMERNet"
,
"demo.yaml"
))
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
mfr_model_name
])
)
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
'UniMERNet'
,
'demo.yaml'
))
self
.
mfr_model
=
atom_model_manager
.
get_atom_model
(
self
.
mfr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFR
,
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_cfg_path
=
mfr_cfg_path
,
mfr_cfg_path
=
mfr_cfg_path
,
device
=
self
.
device
device
=
self
.
device
,
)
)
# 初始化layout模型
# 初始化layout模型
...
@@ -104,16 +125,28 @@ class CustomPEKModel:
...
@@ -104,16 +125,28 @@ class CustomPEKModel:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
LAYOUTLMv3
,
layout_model_name
=
MODEL_NAME
.
LAYOUTLMv3
,
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
])),
layout_weights
=
str
(
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
os
.
path
.
join
(
device
=
self
.
device
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
]
)
),
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
'layoutlmv3'
,
'layoutlmv3_base_inference.yaml'
)
),
device
=
self
.
device
,
)
)
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
DocLayout_YOLO
,
layout_model_name
=
MODEL_NAME
.
DocLayout_YOLO
,
doclayout_yolo_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
])),
doclayout_yolo_weights
=
str
(
device
=
self
.
device
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
]
)
),
device
=
self
.
device
,
)
)
# 初始化ocr
# 初始化ocr
if
self
.
apply_ocr
:
if
self
.
apply_ocr
:
...
@@ -121,23 +154,22 @@ class CustomPEKModel:
...
@@ -121,23 +154,22 @@ class CustomPEKModel:
atom_model_name
=
AtomicModel
.
OCR
,
atom_model_name
=
AtomicModel
.
OCR
,
ocr_show_log
=
show_log
,
ocr_show_log
=
show_log
,
det_db_box_thresh
=
0.3
,
det_db_box_thresh
=
0.3
,
lang
=
self
.
lang
lang
=
self
.
lang
,
)
)
# init table model
# init table model
if
self
.
apply_table
:
if
self
.
apply_table
:
table_model_dir
=
self
.
configs
[
"
weights
"
][
self
.
table_model_name
]
table_model_dir
=
self
.
configs
[
'
weights
'
][
self
.
table_model_name
]
self
.
table_model
=
atom_model_manager
.
get_atom_model
(
self
.
table_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Table
,
atom_model_name
=
AtomicModel
.
Table
,
table_model_name
=
self
.
table_model_name
,
table_model_name
=
self
.
table_model_name
,
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_max_time
=
self
.
table_max_time
,
table_max_time
=
self
.
table_max_time
,
device
=
self
.
device
device
=
self
.
device
,
)
)
logger
.
info
(
'DocAnalysis init done!'
)
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
image
):
def
__call__
(
self
,
image
):
page_start
=
time
.
time
()
page_start
=
time
.
time
()
# layout检测
# layout检测
...
@@ -150,7 +182,7 @@ class CustomPEKModel:
...
@@ -150,7 +182,7 @@ class CustomPEKModel:
# doclayout_yolo
# doclayout_yolo
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
logger
.
info
(
f
"
layout detection time:
{
layout_cost
}
"
)
logger
.
info
(
f
'
layout detection time:
{
layout_cost
}
'
)
pil_img
=
Image
.
fromarray
(
image
)
pil_img
=
Image
.
fromarray
(
image
)
...
@@ -158,32 +190,40 @@ class CustomPEKModel:
...
@@ -158,32 +190,40 @@ class CustomPEKModel:
# 公式检测
# 公式检测
mfd_start
=
time
.
time
()
mfd_start
=
time
.
time
()
mfd_res
=
self
.
mfd_model
.
predict
(
image
)
mfd_res
=
self
.
mfd_model
.
predict
(
image
)
logger
.
info
(
f
"
mfd time:
{
round
(
time
.
time
()
-
mfd_start
,
2
)
}
"
)
logger
.
info
(
f
'
mfd time:
{
round
(
time
.
time
()
-
mfd_start
,
2
)
}
'
)
# 公式识别
# 公式识别
mfr_start
=
time
.
time
()
mfr_start
=
time
.
time
()
formula_list
=
self
.
mfr_model
.
predict
(
mfd_res
,
image
)
formula_list
=
self
.
mfr_model
.
predict
(
mfd_res
,
image
)
layout_res
.
extend
(
formula_list
)
layout_res
.
extend
(
formula_list
)
mfr_cost
=
round
(
time
.
time
()
-
mfr_start
,
2
)
mfr_cost
=
round
(
time
.
time
()
-
mfr_start
,
2
)
logger
.
info
(
f
"
formula nums:
{
len
(
formula_list
)
}
, mfr time:
{
mfr_cost
}
"
)
logger
.
info
(
f
'
formula nums:
{
len
(
formula_list
)
}
, mfr time:
{
mfr_cost
}
'
)
# 清理显存
# 清理显存
clean_vram
(
self
.
device
,
vram_threshold
=
8
)
clean_vram
(
self
.
device
,
vram_threshold
=
8
)
# 从layout_res中获取ocr区域、表格区域、公式区域
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
get_res_list_from_layout_res
(
layout_res
)
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
(
get_res_list_from_layout_res
(
layout_res
)
)
# ocr识别
# ocr识别
if
self
.
apply_ocr
:
if
self
.
apply_ocr
:
ocr_start
=
time
.
time
()
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
new_image
,
useful_list
=
crop_img
(
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
# Integration results
# Integration results
if
ocr_res
:
if
ocr_res
:
...
@@ -191,7 +231,7 @@ class CustomPEKModel:
...
@@ -191,7 +231,7 @@ class CustomPEKModel:
layout_res
.
extend
(
ocr_result_list
)
layout_res
.
extend
(
ocr_result_list
)
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
logger
.
info
(
f
"
ocr time:
{
ocr_cost
}
"
)
logger
.
info
(
f
'
ocr time:
{
ocr_cost
}
'
)
# 表格识别 table recognition
# 表格识别 table recognition
if
self
.
apply_table
:
if
self
.
apply_table
:
...
@@ -202,27 +242,37 @@ class CustomPEKModel:
...
@@ -202,27 +242,37 @@ class CustomPEKModel:
html_code
=
None
html_code
=
None
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
table_result
=
self
.
table_model
.
predict
(
new_image
,
"
html
"
)
table_result
=
self
.
table_model
.
predict
(
new_image
,
'
html
'
)
if
len
(
table_result
)
>
0
:
if
len
(
table_result
)
>
0
:
html_code
=
table_result
[
0
]
html_code
=
table_result
[
0
]
elif
self
.
table_model_name
==
MODEL_NAME
.
TABLE_MASTER
:
elif
self
.
table_model_name
==
MODEL_NAME
.
TABLE_MASTER
:
html_code
=
self
.
table_model
.
img2html
(
new_image
)
html_code
=
self
.
table_model
.
img2html
(
new_image
)
elif
self
.
table_model_name
==
MODEL_NAME
.
RAPID_TABLE
:
elif
self
.
table_model_name
==
MODEL_NAME
.
RAPID_TABLE
:
html_code
,
table_cell_bboxes
,
elapse
=
self
.
table_model
.
predict
(
new_image
)
html_code
,
table_cell_bboxes
,
elapse
=
self
.
table_model
.
predict
(
new_image
)
run_time
=
time
.
time
()
-
single_table_start_time
run_time
=
time
.
time
()
-
single_table_start_time
if
run_time
>
self
.
table_max_time
:
if
run_time
>
self
.
table_max_time
:
logger
.
warning
(
f
"table recognition processing exceeds max time
{
self
.
table_max_time
}
s"
)
logger
.
warning
(
f
'table recognition processing exceeds max time
{
self
.
table_max_time
}
s'
)
# 判断是否返回正常
# 判断是否返回正常
if
html_code
:
if
html_code
:
expected_ending
=
html_code
.
strip
().
endswith
(
'</html>'
)
or
html_code
.
strip
().
endswith
(
'</table>'
)
expected_ending
=
html_code
.
strip
().
endswith
(
'</html>'
)
or
html_code
.
strip
().
endswith
(
'</table>'
)
if
expected_ending
:
if
expected_ending
:
res
[
"
html
"
]
=
html_code
res
[
'
html
'
]
=
html_code
else
:
else
:
logger
.
warning
(
f
"table recognition processing fails, not found expected HTML table end"
)
logger
.
warning
(
'table recognition processing fails, not found expected HTML table end'
)
else
:
else
:
logger
.
warning
(
f
"table recognition processing fails, not get html return"
)
logger
.
warning
(
logger
.
info
(
f
"table time:
{
round
(
time
.
time
()
-
table_start
,
2
)
}
"
)
'table recognition processing fails, not get html return'
)
logger
.
info
(
f
'table time:
{
round
(
time
.
time
()
-
table_start
,
2
)
}
'
)
logger
.
info
(
f
"
-----page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----
"
)
logger
.
info
(
f
'
-----page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----
'
)
return
layout_res
return
layout_res
magic_pdf/model/sub_modules/model_init.py
View file @
ea2f8ea0
from
loguru
import
logger
from
loguru
import
logger
from
magic_pdf.
libs.C
onstants
import
MODEL_NAME
from
magic_pdf.
config.c
onstants
import
MODEL_NAME
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO
import
DocLayoutYOLOModel
from
magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO
import
\
from
magic_pdf.model.sub_modules.layout.layoutlmv3.model_init
import
Layoutlmv3_Predictor
DocLayoutYOLOModel
from
magic_pdf.model.sub_modules.layout.layoutlmv3.model_init
import
\
Layoutlmv3_Predictor
from
magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8
import
YOLOv8MFDModel
from
magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8
import
YOLOv8MFDModel
from
magic_pdf.model.sub_modules.mfr.unimernet.Unimernet
import
UnimernetModel
from
magic_pdf.model.sub_modules.mfr.unimernet.Unimernet
import
UnimernetModel
from
magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod
import
ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod
import
\
ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.table.rapidtable.rapid_table
import
\
RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable
import
StructTableModel
from
magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable
import
\
from
magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle
import
TableMasterPaddleModel
StructTableModel
from
magic_pdf.model.sub_modules.table.rapidtable.rapid_table
import
RapidTableModel
from
magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle
import
\
TableMasterPaddleModel
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
):
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
):
...
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
...
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
table_model
=
StructTableModel
(
model_path
,
max_new_tokens
=
2048
,
max_time
=
max_time
)
table_model
=
StructTableModel
(
model_path
,
max_new_tokens
=
2048
,
max_time
=
max_time
)
elif
table_model_type
==
MODEL_NAME
.
TABLE_MASTER
:
elif
table_model_type
==
MODEL_NAME
.
TABLE_MASTER
:
config
=
{
config
=
{
"
model_dir
"
:
model_path
,
'
model_dir
'
:
model_path
,
"
device
"
:
_device_
'
device
'
:
_device_
}
}
table_model
=
TableMasterPaddleModel
(
config
)
table_model
=
TableMasterPaddleModel
(
config
)
elif
table_model_type
==
MODEL_NAME
.
RAPID_TABLE
:
elif
table_model_type
==
MODEL_NAME
.
RAPID_TABLE
:
table_model
=
RapidTableModel
()
table_model
=
RapidTableModel
()
else
:
else
:
logger
.
error
(
"
table model type not allow
"
)
logger
.
error
(
'
table model type not allow
'
)
exit
(
1
)
exit
(
1
)
return
table_model
return
table_model
...
@@ -87,8 +92,8 @@ class AtomModelSingleton:
...
@@ -87,8 +92,8 @@ class AtomModelSingleton:
return
cls
.
_instance
return
cls
.
_instance
def
get_atom_model
(
self
,
atom_model_name
:
str
,
**
kwargs
):
def
get_atom_model
(
self
,
atom_model_name
:
str
,
**
kwargs
):
lang
=
kwargs
.
get
(
"
lang
"
,
None
)
lang
=
kwargs
.
get
(
'
lang
'
,
None
)
layout_model_name
=
kwargs
.
get
(
"
layout_model_name
"
,
None
)
layout_model_name
=
kwargs
.
get
(
'
layout_model_name
'
,
None
)
key
=
(
atom_model_name
,
layout_model_name
,
lang
)
key
=
(
atom_model_name
,
layout_model_name
,
lang
)
if
key
not
in
self
.
_models
:
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
self
.
_models
[
key
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
...
@@ -98,47 +103,47 @@ class AtomModelSingleton:
...
@@ -98,47 +103,47 @@ class AtomModelSingleton:
def
atom_model_init
(
model_name
:
str
,
**
kwargs
):
def
atom_model_init
(
model_name
:
str
,
**
kwargs
):
atom_model
=
None
atom_model
=
None
if
model_name
==
AtomicModel
.
Layout
:
if
model_name
==
AtomicModel
.
Layout
:
if
kwargs
.
get
(
"
layout_model_name
"
)
==
MODEL_NAME
.
LAYOUTLMv3
:
if
kwargs
.
get
(
'
layout_model_name
'
)
==
MODEL_NAME
.
LAYOUTLMv3
:
atom_model
=
layout_model_init
(
atom_model
=
layout_model_init
(
kwargs
.
get
(
"
layout_weights
"
),
kwargs
.
get
(
'
layout_weights
'
),
kwargs
.
get
(
"
layout_config_file
"
),
kwargs
.
get
(
'
layout_config_file
'
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
device
'
)
)
)
elif
kwargs
.
get
(
"
layout_model_name
"
)
==
MODEL_NAME
.
DocLayout_YOLO
:
elif
kwargs
.
get
(
'
layout_model_name
'
)
==
MODEL_NAME
.
DocLayout_YOLO
:
atom_model
=
doclayout_yolo_model_init
(
atom_model
=
doclayout_yolo_model_init
(
kwargs
.
get
(
"
doclayout_yolo_weights
"
),
kwargs
.
get
(
'
doclayout_yolo_weights
'
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
device
'
)
)
)
elif
model_name
==
AtomicModel
.
MFD
:
elif
model_name
==
AtomicModel
.
MFD
:
atom_model
=
mfd_model_init
(
atom_model
=
mfd_model_init
(
kwargs
.
get
(
"
mfd_weights
"
),
kwargs
.
get
(
'
mfd_weights
'
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
device
'
)
)
)
elif
model_name
==
AtomicModel
.
MFR
:
elif
model_name
==
AtomicModel
.
MFR
:
atom_model
=
mfr_model_init
(
atom_model
=
mfr_model_init
(
kwargs
.
get
(
"
mfr_weight_dir
"
),
kwargs
.
get
(
'
mfr_weight_dir
'
),
kwargs
.
get
(
"
mfr_cfg_path
"
),
kwargs
.
get
(
'
mfr_cfg_path
'
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
device
'
)
)
)
elif
model_name
==
AtomicModel
.
OCR
:
elif
model_name
==
AtomicModel
.
OCR
:
atom_model
=
ocr_model_init
(
atom_model
=
ocr_model_init
(
kwargs
.
get
(
"
ocr_show_log
"
),
kwargs
.
get
(
'
ocr_show_log
'
),
kwargs
.
get
(
"
det_db_box_thresh
"
),
kwargs
.
get
(
'
det_db_box_thresh
'
),
kwargs
.
get
(
"
lang
"
)
kwargs
.
get
(
'
lang
'
)
)
)
elif
model_name
==
AtomicModel
.
Table
:
elif
model_name
==
AtomicModel
.
Table
:
atom_model
=
table_model_init
(
atom_model
=
table_model_init
(
kwargs
.
get
(
"
table_model_name
"
),
kwargs
.
get
(
'
table_model_name
'
),
kwargs
.
get
(
"
table_model_path
"
),
kwargs
.
get
(
'
table_model_path
'
),
kwargs
.
get
(
"
table_max_time
"
),
kwargs
.
get
(
'
table_max_time
'
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
device
'
)
)
)
else
:
else
:
logger
.
error
(
"
model name not allow
"
)
logger
.
error
(
'
model name not allow
'
)
exit
(
1
)
exit
(
1
)
if
atom_model
is
None
:
if
atom_model
is
None
:
logger
.
error
(
"
model init failed
"
)
logger
.
error
(
'
model init failed
'
)
exit
(
1
)
exit
(
1
)
else
:
else
:
return
atom_model
return
atom_model
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment