Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
3a42ebbf
Unverified
Commit
3a42ebbf
authored
Nov 01, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Nov 01, 2024
Browse files
Merge pull request #838 from opendatalab/release-0.9.0
Release 0.9.0
parents
765c6d77
14024793
Changes
591
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1312 additions
and
65 deletions
+1312
-65
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py
...f/model/pek_sub_modules/structeqtable/StructTableModel.py
+8
-1
magic_pdf/model/ppTableModel.py
magic_pdf/model/ppTableModel.py
+2
-2
magic_pdf/model/pp_structure_v2.py
magic_pdf/model/pp_structure_v2.py
+5
-2
magic_pdf/model/v3/__init__.py
magic_pdf/model/v3/__init__.py
+0
-0
magic_pdf/model/v3/helpers.py
magic_pdf/model/v3/helpers.py
+125
-0
magic_pdf/para/para_split_v3.py
magic_pdf/para/para_split_v3.py
+296
-0
magic_pdf/pdf_parse_by_ocr.py
magic_pdf/pdf_parse_by_ocr.py
+6
-3
magic_pdf/pdf_parse_by_txt.py
magic_pdf/pdf_parse_by_txt.py
+6
-3
magic_pdf/pdf_parse_union_core_v2.py
magic_pdf/pdf_parse_union_core_v2.py
+644
-0
magic_pdf/pipe/AbsPipe.py
magic_pdf/pipe/AbsPipe.py
+5
-1
magic_pdf/pipe/OCRPipe.py
magic_pdf/pipe/OCRPipe.py
+10
-4
magic_pdf/pipe/TXTPipe.py
magic_pdf/pipe/TXTPipe.py
+10
-4
magic_pdf/pipe/UNIPipe.py
magic_pdf/pipe/UNIPipe.py
+16
-7
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
+83
-1
magic_pdf/pre_proc/ocr_dict_merge.py
magic_pdf/pre_proc/ocr_dict_merge.py
+27
-2
magic_pdf/resources/model_config/UniMERNet/demo.yaml
magic_pdf/resources/model_config/UniMERNet/demo.yaml
+7
-7
magic_pdf/resources/model_config/model_configs.yaml
magic_pdf/resources/model_config/model_configs.yaml
+5
-13
magic_pdf/tools/cli.py
magic_pdf/tools/cli.py
+14
-1
magic_pdf/tools/common.py
magic_pdf/tools/common.py
+18
-8
magic_pdf/user_api.py
magic_pdf/user_api.py
+25
-6
No files found.
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py
View file @
3a42ebbf
from
struct_eqtable.model
import
StructTable
from
loguru
import
logger
try
:
from
struct_eqtable.model
import
StructTable
except
ImportError
:
logger
.
error
(
"StructEqTable is under upgrade, the current version does not support it."
)
from
pypandoc
import
convert_text
class
StructTableModel
:
def
__init__
(
self
,
model_path
,
max_new_tokens
=
2048
,
max_time
=
400
,
device
=
'cpu'
):
# init
...
...
magic_pdf/model/ppTableModel.py
View file @
3a42ebbf
...
...
@@ -52,11 +52,11 @@ class ppTableModel(object):
rec_model_dir
=
os
.
path
.
join
(
model_dir
,
REC_MODEL_DIR
)
rec_char_dict_path
=
os
.
path
.
join
(
model_dir
,
REC_CHAR_DICT
)
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
use_gpu
=
True
if
device
==
"cuda"
else
False
use_gpu
=
True
if
device
.
startswith
(
"cuda"
)
else
False
config
=
{
"use_gpu"
:
use_gpu
,
"table_max_len"
:
kwargs
.
get
(
"table_max_len"
,
TABLE_MAX_LEN
),
"table_algorithm"
:
TABLE_MASTER
,
"table_algorithm"
:
"TableMaster"
,
"table_model_dir"
:
table_model_dir
,
"table_char_dict_path"
:
table_char_dict_path
,
"det_model_dir"
:
det_model_dir
,
...
...
magic_pdf/model/pp_structure_v2.py
View file @
3a42ebbf
...
...
@@ -18,7 +18,10 @@ def region_to_bbox(region):
class
CustomPaddleModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
):
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
):
if
lang
is
not
None
:
self
.
model
=
PPStructure
(
table
=
False
,
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
)
else
:
self
.
model
=
PPStructure
(
table
=
False
,
ocr
=
ocr
,
show_log
=
show_log
)
def
__call__
(
self
,
img
):
...
...
magic_pdf/model/v3/__init__.py
0 → 100644
View file @
3a42ebbf
magic_pdf/model/v3/helpers.py
0 → 100644
View file @
3a42ebbf
from
collections
import
defaultdict
from
typing
import
List
,
Dict
import
torch
from
transformers
import
LayoutLMv3ForTokenClassification
MAX_LEN
=
510
CLS_TOKEN_ID
=
0
UNK_TOKEN_ID
=
3
EOS_TOKEN_ID
=
2
class
DataCollator
:
def
__call__
(
self
,
features
:
List
[
dict
])
->
Dict
[
str
,
torch
.
Tensor
]:
bbox
=
[]
labels
=
[]
input_ids
=
[]
attention_mask
=
[]
# clip bbox and labels to max length, build input_ids and attention_mask
for
feature
in
features
:
_bbox
=
feature
[
"source_boxes"
]
if
len
(
_bbox
)
>
MAX_LEN
:
_bbox
=
_bbox
[:
MAX_LEN
]
_labels
=
feature
[
"target_index"
]
if
len
(
_labels
)
>
MAX_LEN
:
_labels
=
_labels
[:
MAX_LEN
]
_input_ids
=
[
UNK_TOKEN_ID
]
*
len
(
_bbox
)
_attention_mask
=
[
1
]
*
len
(
_bbox
)
assert
len
(
_bbox
)
==
len
(
_labels
)
==
len
(
_input_ids
)
==
len
(
_attention_mask
)
bbox
.
append
(
_bbox
)
labels
.
append
(
_labels
)
input_ids
.
append
(
_input_ids
)
attention_mask
.
append
(
_attention_mask
)
# add CLS and EOS tokens
for
i
in
range
(
len
(
bbox
)):
bbox
[
i
]
=
[[
0
,
0
,
0
,
0
]]
+
bbox
[
i
]
+
[[
0
,
0
,
0
,
0
]]
labels
[
i
]
=
[
-
100
]
+
labels
[
i
]
+
[
-
100
]
input_ids
[
i
]
=
[
CLS_TOKEN_ID
]
+
input_ids
[
i
]
+
[
EOS_TOKEN_ID
]
attention_mask
[
i
]
=
[
1
]
+
attention_mask
[
i
]
+
[
1
]
# padding to max length
max_len
=
max
(
len
(
x
)
for
x
in
bbox
)
for
i
in
range
(
len
(
bbox
)):
bbox
[
i
]
=
bbox
[
i
]
+
[[
0
,
0
,
0
,
0
]]
*
(
max_len
-
len
(
bbox
[
i
]))
labels
[
i
]
=
labels
[
i
]
+
[
-
100
]
*
(
max_len
-
len
(
labels
[
i
]))
input_ids
[
i
]
=
input_ids
[
i
]
+
[
EOS_TOKEN_ID
]
*
(
max_len
-
len
(
input_ids
[
i
]))
attention_mask
[
i
]
=
attention_mask
[
i
]
+
[
0
]
*
(
max_len
-
len
(
attention_mask
[
i
])
)
ret
=
{
"bbox"
:
torch
.
tensor
(
bbox
),
"attention_mask"
:
torch
.
tensor
(
attention_mask
),
"labels"
:
torch
.
tensor
(
labels
),
"input_ids"
:
torch
.
tensor
(
input_ids
),
}
# set label > MAX_LEN to -100, because original labels may be > MAX_LEN
ret
[
"labels"
][
ret
[
"labels"
]
>
MAX_LEN
]
=
-
100
# set label > 0 to label-1, because original labels are 1-indexed
ret
[
"labels"
][
ret
[
"labels"
]
>
0
]
-=
1
return
ret
def
boxes2inputs
(
boxes
:
List
[
List
[
int
]])
->
Dict
[
str
,
torch
.
Tensor
]:
bbox
=
[[
0
,
0
,
0
,
0
]]
+
boxes
+
[[
0
,
0
,
0
,
0
]]
input_ids
=
[
CLS_TOKEN_ID
]
+
[
UNK_TOKEN_ID
]
*
len
(
boxes
)
+
[
EOS_TOKEN_ID
]
attention_mask
=
[
1
]
+
[
1
]
*
len
(
boxes
)
+
[
1
]
return
{
"bbox"
:
torch
.
tensor
([
bbox
]),
"attention_mask"
:
torch
.
tensor
([
attention_mask
]),
"input_ids"
:
torch
.
tensor
([
input_ids
]),
}
def
prepare_inputs
(
inputs
:
Dict
[
str
,
torch
.
Tensor
],
model
:
LayoutLMv3ForTokenClassification
)
->
Dict
[
str
,
torch
.
Tensor
]:
ret
=
{}
for
k
,
v
in
inputs
.
items
():
v
=
v
.
to
(
model
.
device
)
if
torch
.
is_floating_point
(
v
):
v
=
v
.
to
(
model
.
dtype
)
ret
[
k
]
=
v
return
ret
def
parse_logits
(
logits
:
torch
.
Tensor
,
length
:
int
)
->
List
[
int
]:
"""
parse logits to orders
:param logits: logits from model
:param length: input length
:return: orders
"""
logits
=
logits
[
1
:
length
+
1
,
:
length
]
orders
=
logits
.
argsort
(
descending
=
False
).
tolist
()
ret
=
[
o
.
pop
()
for
o
in
orders
]
while
True
:
order_to_idxes
=
defaultdict
(
list
)
for
idx
,
order
in
enumerate
(
ret
):
order_to_idxes
[
order
].
append
(
idx
)
# filter idxes len > 1
order_to_idxes
=
{
k
:
v
for
k
,
v
in
order_to_idxes
.
items
()
if
len
(
v
)
>
1
}
if
not
order_to_idxes
:
break
# filter
for
order
,
idxes
in
order_to_idxes
.
items
():
# find original logits of idxes
idxes_to_logit
=
{}
for
idx
in
idxes
:
idxes_to_logit
[
idx
]
=
logits
[
idx
,
order
]
idxes_to_logit
=
sorted
(
idxes_to_logit
.
items
(),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
# keep the highest logit as order, set others to next candidate
for
idx
,
_
in
idxes_to_logit
[
1
:]:
ret
[
idx
]
=
orders
[
idx
].
pop
()
return
ret
def
check_duplicate
(
a
:
List
[
int
])
->
bool
:
return
len
(
a
)
!=
len
(
set
(
a
))
magic_pdf/para/para_split_v3.py
0 → 100644
View file @
3a42ebbf
import
copy
from
loguru
import
logger
from
magic_pdf.libs.Constants
import
LINES_DELETED
,
CROSS_PAGE
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
ContentType
LINE_STOP_FLAG
=
(
'.'
,
'!'
,
'?'
,
'。'
,
'!'
,
'?'
,
')'
,
')'
,
'"'
,
'”'
,
':'
,
':'
,
';'
,
';'
)
LIST_END_FLAG
=
(
'.'
,
'。'
,
';'
,
';'
)
class
ListLineTag
:
IS_LIST_START_LINE
=
"is_list_start_line"
IS_LIST_END_LINE
=
"is_list_end_line"
def
__process_blocks
(
blocks
):
# 对所有block预处理
# 1.通过title和interline_equation将block分组
# 2.bbox边界根据line信息重置
result
=
[]
current_group
=
[]
for
i
in
range
(
len
(
blocks
)):
current_block
=
blocks
[
i
]
# 如果当前块是 text 类型
if
current_block
[
'type'
]
==
'text'
:
current_block
[
"bbox_fs"
]
=
copy
.
deepcopy
(
current_block
[
"bbox"
])
if
'lines'
in
current_block
and
len
(
current_block
[
"lines"
])
>
0
:
current_block
[
'bbox_fs'
]
=
[
min
([
line
[
'bbox'
][
0
]
for
line
in
current_block
[
'lines'
]]),
min
([
line
[
'bbox'
][
1
]
for
line
in
current_block
[
'lines'
]]),
max
([
line
[
'bbox'
][
2
]
for
line
in
current_block
[
'lines'
]]),
max
([
line
[
'bbox'
][
3
]
for
line
in
current_block
[
'lines'
]])]
current_group
.
append
(
current_block
)
# 检查下一个块是否存在
if
i
+
1
<
len
(
blocks
):
next_block
=
blocks
[
i
+
1
]
# 如果下一个块不是 text 类型且是 title 或 interline_equation 类型
if
next_block
[
'type'
]
in
[
'title'
,
'interline_equation'
]:
result
.
append
(
current_group
)
current_group
=
[]
# 处理最后一个 group
if
current_group
:
result
.
append
(
current_group
)
return
result
def
__is_list_or_index_block
(
block
):
# 一个block如果是list block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 右侧不顶格(狗牙状)
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.多个line以endflag结尾
# 1.block内有多个line 2.block 内有多个line左侧顶格写 3.block内有多个line 左侧不顶格
# index block 是一种特殊的list block
# 一个block如果是index block 应该同时满足以下特征
# 1.block内有多个line 2.block 内有多个line两侧均顶格写 3.line的开头或者结尾均为数字
if
len
(
block
[
'lines'
])
>=
2
:
first_line
=
block
[
'lines'
][
0
]
line_height
=
first_line
[
'bbox'
][
3
]
-
first_line
[
'bbox'
][
1
]
block_weight
=
block
[
'bbox_fs'
][
2
]
-
block
[
'bbox_fs'
][
0
]
left_close_num
=
0
left_not_close_num
=
0
right_not_close_num
=
0
right_close_num
=
0
lines_text_list
=
[]
multiple_para_flag
=
False
last_line
=
block
[
'lines'
][
-
1
]
# 如果首行左边不顶格而右边顶格,末行左边顶格而右边不顶格 (第一行可能可以右边不顶格)
if
(
first_line
[
'bbox'
][
0
]
-
block
[
'bbox_fs'
][
0
]
>
line_height
/
2
and
# block['bbox_fs'][2] - first_line['bbox'][2] < line_height and
abs
(
last_line
[
'bbox'
][
0
]
-
block
[
'bbox_fs'
][
0
])
<
line_height
/
2
and
block
[
'bbox_fs'
][
2
]
-
last_line
[
'bbox'
][
2
]
>
line_height
):
multiple_para_flag
=
True
for
line
in
block
[
'lines'
]:
line_text
=
""
for
span
in
line
[
'spans'
]:
span_type
=
span
[
'type'
]
if
span_type
==
ContentType
.
Text
:
line_text
+=
span
[
'content'
].
strip
()
lines_text_list
.
append
(
line_text
)
# 计算line左侧顶格数量是否大于2,是否顶格用abs(block['bbox_fs'][0] - line['bbox'][0]) < line_height/2 来判断
if
abs
(
block
[
'bbox_fs'
][
0
]
-
line
[
'bbox'
][
0
])
<
line_height
/
2
:
left_close_num
+=
1
elif
line
[
'bbox'
][
0
]
-
block
[
'bbox_fs'
][
0
]
>
line_height
:
# logger.info(f"{line_text}, {block['bbox_fs']}, {line['bbox']}")
left_not_close_num
+=
1
# 计算右侧是否顶格
if
abs
(
block
[
'bbox_fs'
][
2
]
-
line
[
'bbox'
][
2
])
<
line_height
:
right_close_num
+=
1
else
:
# 右侧不顶格情况下是否有一段距离,拍脑袋用0.3block宽度做阈值
closed_area
=
0.3
*
block_weight
# closed_area = 5 * line_height
if
block
[
'bbox_fs'
][
2
]
-
line
[
'bbox'
][
2
]
>
closed_area
:
right_not_close_num
+=
1
# 判断lines_text_list中的元素是否有超过80%都以LIST_END_FLAG结尾
line_end_flag
=
False
# 判断lines_text_list中的元素是否有超过80%都以数字开头或都以数字结尾
line_num_flag
=
False
num_start_count
=
0
num_end_count
=
0
flag_end_count
=
0
if
len
(
lines_text_list
)
>
0
:
for
line_text
in
lines_text_list
:
if
len
(
line_text
)
>
0
:
if
line_text
[
-
1
]
in
LIST_END_FLAG
:
flag_end_count
+=
1
if
line_text
[
0
].
isdigit
():
num_start_count
+=
1
if
line_text
[
-
1
].
isdigit
():
num_end_count
+=
1
if
flag_end_count
/
len
(
lines_text_list
)
>=
0.8
:
line_end_flag
=
True
if
num_start_count
/
len
(
lines_text_list
)
>=
0.8
or
num_end_count
/
len
(
lines_text_list
)
>=
0.8
:
line_num_flag
=
True
# 有的目录右侧不贴边, 目前认为左边或者右边有一边全贴边,且符合数字规则极为index
if
((
left_close_num
/
len
(
block
[
'lines'
])
>=
0.8
or
right_close_num
/
len
(
block
[
'lines'
])
>=
0.8
)
and
line_num_flag
):
for
line
in
block
[
'lines'
]:
line
[
ListLineTag
.
IS_LIST_START_LINE
]
=
True
return
BlockType
.
Index
elif
left_close_num
>=
2
and
(
right_not_close_num
>=
2
or
line_end_flag
or
left_not_close_num
>=
2
)
and
not
multiple_para_flag
:
# 处理一种特殊的没有缩进的list,所有行都贴左边,通过右边的空隙判断是否是item尾
if
left_close_num
/
len
(
block
[
'lines'
])
>
0.9
:
# 这种是每个item只有一行,且左边都贴边的短item list
if
flag_end_count
==
0
and
right_close_num
/
len
(
block
[
'lines'
])
<
0.5
:
for
line
in
block
[
'lines'
]:
if
abs
(
block
[
'bbox_fs'
][
0
]
-
line
[
'bbox'
][
0
])
<
line_height
/
2
:
line
[
ListLineTag
.
IS_LIST_START_LINE
]
=
True
# 这种是大部分line item 都有结束标识符的情况,按结束标识符区分不同item
elif
line_end_flag
:
for
i
,
line
in
enumerate
(
block
[
'lines'
]):
if
lines_text_list
[
i
][
-
1
]
in
LIST_END_FLAG
:
line
[
ListLineTag
.
IS_LIST_END_LINE
]
=
True
if
i
+
1
<
len
(
block
[
'lines'
]):
block
[
'lines'
][
i
+
1
][
ListLineTag
.
IS_LIST_START_LINE
]
=
True
# line item基本没有结束标识符,而且也没有缩进,按右侧空隙判断哪些是item end
else
:
line_start_flag
=
False
for
i
,
line
in
enumerate
(
block
[
'lines'
]):
if
line_start_flag
:
line
[
ListLineTag
.
IS_LIST_START_LINE
]
=
True
line_start_flag
=
False
elif
abs
(
block
[
'bbox_fs'
][
2
]
-
line
[
'bbox'
][
2
])
>
line_height
:
line
[
ListLineTag
.
IS_LIST_END_LINE
]
=
True
line_start_flag
=
True
# 一种有缩进的特殊有序list,start line 左侧不贴边且以数字开头,end line 以 IS_LIST_END_LINE 结尾且数量和start line 一致
elif
num_start_count
>=
2
and
num_start_count
==
flag_end_count
:
# 简单一点先不考虑左侧不贴边的情况
for
i
,
line
in
enumerate
(
block
[
'lines'
]):
if
lines_text_list
[
i
][
0
].
isdigit
():
line
[
ListLineTag
.
IS_LIST_START_LINE
]
=
True
if
lines_text_list
[
i
][
-
1
]
in
LIST_END_FLAG
:
line
[
ListLineTag
.
IS_LIST_END_LINE
]
=
True
else
:
# 正常有缩进的list处理
for
line
in
block
[
'lines'
]:
if
abs
(
block
[
'bbox_fs'
][
0
]
-
line
[
'bbox'
][
0
])
<
line_height
/
2
:
line
[
ListLineTag
.
IS_LIST_START_LINE
]
=
True
if
abs
(
block
[
'bbox_fs'
][
2
]
-
line
[
'bbox'
][
2
])
>
line_height
:
line
[
ListLineTag
.
IS_LIST_END_LINE
]
=
True
return
BlockType
.
List
else
:
return
BlockType
.
Text
else
:
return
BlockType
.
Text
def
__merge_2_text_blocks
(
block1
,
block2
):
if
len
(
block1
[
'lines'
])
>
0
:
first_line
=
block1
[
'lines'
][
0
]
line_height
=
first_line
[
'bbox'
][
3
]
-
first_line
[
'bbox'
][
1
]
block1_weight
=
block1
[
'bbox'
][
2
]
-
block1
[
'bbox'
][
0
]
block2_weight
=
block2
[
'bbox'
][
2
]
-
block2
[
'bbox'
][
0
]
min_block_weight
=
min
(
block1_weight
,
block2_weight
)
if
abs
(
block1
[
'bbox_fs'
][
0
]
-
first_line
[
'bbox'
][
0
])
<
line_height
/
2
:
last_line
=
block2
[
'lines'
][
-
1
]
if
len
(
last_line
[
'spans'
])
>
0
:
last_span
=
last_line
[
'spans'
][
-
1
]
line_height
=
last_line
[
'bbox'
][
3
]
-
last_line
[
'bbox'
][
1
]
if
(
abs
(
block2
[
'bbox_fs'
][
2
]
-
last_line
[
'bbox'
][
2
])
<
line_height
and
not
last_span
[
'content'
].
endswith
(
LINE_STOP_FLAG
)
and
# 两个block宽度差距超过2倍也不合并
abs
(
block1_weight
-
block2_weight
)
<
min_block_weight
):
if
block1
[
'page_num'
]
!=
block2
[
'page_num'
]:
for
line
in
block1
[
'lines'
]:
for
span
in
line
[
'spans'
]:
span
[
CROSS_PAGE
]
=
True
block2
[
'lines'
].
extend
(
block1
[
'lines'
])
block1
[
'lines'
]
=
[]
block1
[
LINES_DELETED
]
=
True
return
block1
,
block2
def
__merge_2_list_blocks
(
block1
,
block2
):
if
block1
[
'page_num'
]
!=
block2
[
'page_num'
]:
for
line
in
block1
[
'lines'
]:
for
span
in
line
[
'spans'
]:
span
[
CROSS_PAGE
]
=
True
block2
[
'lines'
].
extend
(
block1
[
'lines'
])
block1
[
'lines'
]
=
[]
block1
[
LINES_DELETED
]
=
True
return
block1
,
block2
def
__is_list_group
(
text_blocks_group
):
# list group的特征是一个group内的所有block都满足以下条件
# 1.每个block都不超过3行 2. 每个block 的左边界都比较接近(逻辑简单点先不加这个规则)
for
block
in
text_blocks_group
:
if
len
(
block
[
'lines'
])
>
3
:
return
False
return
True
def
__para_merge_page
(
blocks
):
page_text_blocks_groups
=
__process_blocks
(
blocks
)
for
text_blocks_group
in
page_text_blocks_groups
:
if
len
(
text_blocks_group
)
>
0
:
# 需要先在合并前对所有block判断是否为list or index block
for
block
in
text_blocks_group
:
block_type
=
__is_list_or_index_block
(
block
)
block
[
'type'
]
=
block_type
# logger.info(f"{block['type']}:{block}")
if
len
(
text_blocks_group
)
>
1
:
# 在合并前判断这个group 是否是一个 list group
is_list_group
=
__is_list_group
(
text_blocks_group
)
# 倒序遍历
for
i
in
range
(
len
(
text_blocks_group
)
-
1
,
-
1
,
-
1
):
current_block
=
text_blocks_group
[
i
]
# 检查是否有前一个块
if
i
-
1
>=
0
:
prev_block
=
text_blocks_group
[
i
-
1
]
if
current_block
[
'type'
]
==
'text'
and
prev_block
[
'type'
]
==
'text'
and
not
is_list_group
:
__merge_2_text_blocks
(
current_block
,
prev_block
)
elif
(
(
current_block
[
'type'
]
==
BlockType
.
List
and
prev_block
[
'type'
]
==
BlockType
.
List
)
or
(
current_block
[
'type'
]
==
BlockType
.
Index
and
prev_block
[
'type'
]
==
BlockType
.
Index
)
):
__merge_2_list_blocks
(
current_block
,
prev_block
)
else
:
continue
def
para_split
(
pdf_info_dict
,
debug_mode
=
False
):
all_blocks
=
[]
for
page_num
,
page
in
pdf_info_dict
.
items
():
blocks
=
copy
.
deepcopy
(
page
[
'preproc_blocks'
])
for
block
in
blocks
:
block
[
'page_num'
]
=
page_num
all_blocks
.
extend
(
blocks
)
__para_merge_page
(
all_blocks
)
for
page_num
,
page
in
pdf_info_dict
.
items
():
page
[
'para_blocks'
]
=
[]
for
block
in
all_blocks
:
if
block
[
'page_num'
]
==
page_num
:
page
[
'para_blocks'
].
append
(
block
)
if
__name__
==
'__main__'
:
input_blocks
=
[]
# 调用函数
groups
=
__process_blocks
(
input_blocks
)
for
group_index
,
group
in
enumerate
(
groups
):
print
(
f
"Group
{
group_index
}
:
{
group
}
"
)
magic_pdf/pdf_parse_by_ocr.py
View file @
3a42ebbf
from
magic_pdf.pdf_parse_union_core
import
pdf_parse_union
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
def
parse_pdf_by_ocr
(
pdf_bytes
,
...
...
@@ -8,10 +10,11 @@ def parse_pdf_by_ocr(pdf_bytes,
end_page_id
=
None
,
debug_mode
=
False
,
):
return
pdf_parse_union
(
pdf_bytes
,
dataset
=
PymuDocDataset
(
pdf_bytes
)
return
pdf_parse_union
(
dataset
,
model_list
,
imageWriter
,
"ocr"
,
SupportedPdfParseMethod
.
OCR
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
debug_mode
=
debug_mode
,
...
...
magic_pdf/pdf_parse_by_txt.py
View file @
3a42ebbf
from
magic_pdf.pdf_parse_union_core
import
pdf_parse_union
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
def
parse_pdf_by_txt
(
...
...
@@ -9,10 +11,11 @@ def parse_pdf_by_txt(
end_page_id
=
None
,
debug_mode
=
False
,
):
return
pdf_parse_union
(
pdf_bytes
,
dataset
=
PymuDocDataset
(
pdf_bytes
)
return
pdf_parse_union
(
dataset
,
model_list
,
imageWriter
,
"txt"
,
SupportedPdfParseMethod
.
TXT
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
debug_mode
=
debug_mode
,
...
...
magic_pdf/pdf_parse_union_core_v2.py
0 → 100644
View file @
3a42ebbf
import
copy
import
os
import
statistics
import
time
from
typing
import
List
import
torch
from
loguru
import
logger
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.dataset
import
Dataset
,
PageableData
from
magic_pdf.libs.boxbase
import
calculate_overlap_area_in_bbox1_area_ratio
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.commons
import
fitz
,
get_delta_time
from
magic_pdf.libs.config_reader
import
get_local_layoutreader_model_dir
from
magic_pdf.libs.convert_utils
import
dict_to_list
from
magic_pdf.libs.drop_reason
import
DropReason
from
magic_pdf.libs.hash_utils
import
compute_md5
from
magic_pdf.libs.local_math
import
float_equal
from
magic_pdf.libs.ocr_content_type
import
ContentType
,
BlockType
from
magic_pdf.model.magic_model
import
MagicModel
from
magic_pdf.para.para_split_v3
import
para_split
from
magic_pdf.pre_proc.citationmarker_remove
import
remove_citation_marker
from
magic_pdf.pre_proc.construct_page_dict
import
\
ocr_construct_page_component_v2
from
magic_pdf.pre_proc.cut_image
import
ocr_cut_image_and_table
from
magic_pdf.pre_proc.equations_replace
import
(
combine_chars_to_pymudict
,
remove_chars_in_text_blocks
,
replace_equations_in_textblock
)
from
magic_pdf.pre_proc.ocr_detect_all_bboxes
import
\
ocr_prepare_bboxes_for_layout_split_v2
from
magic_pdf.pre_proc.ocr_dict_merge
import
(
fill_spans_in_blocks
,
fix_block_spans
,
fix_discarded_block
,
fix_block_spans_v2
)
from
magic_pdf.pre_proc.ocr_span_list_modify
import
(
get_qa_need_list_v2
,
remove_overlaps_low_confidence_spans
,
remove_overlaps_min_spans
)
from
magic_pdf.pre_proc.resolve_bbox_conflict
import
\
check_useful_block_horizontal_overlap
def
remove_horizontal_overlap_block_which_smaller
(
all_bboxes
):
useful_blocks
=
[]
for
bbox
in
all_bboxes
:
useful_blocks
.
append
({
'bbox'
:
bbox
[:
4
]})
is_useful_block_horz_overlap
,
smaller_bbox
,
bigger_bbox
=
(
check_useful_block_horizontal_overlap
(
useful_blocks
)
)
if
is_useful_block_horz_overlap
:
logger
.
warning
(
f
'skip this page, reason:
{
DropReason
.
USEFUL_BLOCK_HOR_OVERLAP
}
, smaller bbox is
{
smaller_bbox
}
, bigger bbox is
{
bigger_bbox
}
'
)
# noqa: E501
for
bbox
in
all_bboxes
.
copy
():
if
smaller_bbox
==
bbox
[:
4
]:
all_bboxes
.
remove
(
bbox
)
return
is_useful_block_horz_overlap
,
all_bboxes
def
__replace_STX_ETX
(
text_str
:
str
):
"""Replace
\u0002
and
\u0003
, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
Drawback: This issue is only observed in English text; it has not been found in Chinese text so far.
Args:
text_str (str): raw text
Returns:
_type_: replaced text
"""
# noqa: E501
if
text_str
:
s
=
text_str
.
replace
(
'
\u0002
'
,
"'"
)
s
=
s
.
replace
(
'
\u0003
'
,
"'"
)
return
s
return
text_str
def
txt_spans_extract
(
pdf_page
,
inline_equations
,
interline_equations
):
text_raw_blocks
=
pdf_page
.
get_text
(
'dict'
,
flags
=
fitz
.
TEXTFLAGS_TEXT
)[
'blocks'
]
char_level_text_blocks
=
pdf_page
.
get_text
(
'rawdict'
,
flags
=
fitz
.
TEXTFLAGS_TEXT
)[
'blocks'
]
text_blocks
=
combine_chars_to_pymudict
(
text_raw_blocks
,
char_level_text_blocks
)
text_blocks
=
replace_equations_in_textblock
(
text_blocks
,
inline_equations
,
interline_equations
)
text_blocks
=
remove_citation_marker
(
text_blocks
)
text_blocks
=
remove_chars_in_text_blocks
(
text_blocks
)
spans
=
[]
for
v
in
text_blocks
:
for
line
in
v
[
'lines'
]:
for
span
in
line
[
'spans'
]:
bbox
=
span
[
'bbox'
]
if
float_equal
(
bbox
[
0
],
bbox
[
2
])
or
float_equal
(
bbox
[
1
],
bbox
[
3
]):
continue
if
span
.
get
(
'type'
)
not
in
(
ContentType
.
InlineEquation
,
ContentType
.
InterlineEquation
,
):
spans
.
append
(
{
'bbox'
:
list
(
span
[
'bbox'
]),
'content'
:
__replace_STX_ETX
(
span
[
'text'
]),
'type'
:
ContentType
.
Text
,
'score'
:
1.0
,
}
)
return
spans
def
replace_text_span
(
pymu_spans
,
ocr_spans
):
return
list
(
filter
(
lambda
x
:
x
[
'type'
]
!=
ContentType
.
Text
,
ocr_spans
))
+
pymu_spans
def
model_init
(
model_name
:
str
):
from
transformers
import
LayoutLMv3ForTokenClassification
if
torch
.
cuda
.
is_available
():
device
=
torch
.
device
(
'cuda'
)
if
torch
.
cuda
.
is_bf16_supported
():
supports_bfloat16
=
True
else
:
supports_bfloat16
=
False
else
:
device
=
torch
.
device
(
'cpu'
)
supports_bfloat16
=
False
if
model_name
==
'layoutreader'
:
# 检测modelscope的缓存目录是否存在
layoutreader_model_dir
=
get_local_layoutreader_model_dir
()
if
os
.
path
.
exists
(
layoutreader_model_dir
):
model
=
LayoutLMv3ForTokenClassification
.
from_pretrained
(
layoutreader_model_dir
)
else
:
logger
.
warning
(
'local layoutreader model not exists, use online model from huggingface'
)
model
=
LayoutLMv3ForTokenClassification
.
from_pretrained
(
'hantian/layoutreader'
)
# 检查设备是否支持 bfloat16
if
supports_bfloat16
:
model
.
bfloat16
()
model
.
to
(
device
).
eval
()
else
:
logger
.
error
(
'model name not allow'
)
exit
(
1
)
return
model
class
ModelSingleton
:
_instance
=
None
_models
=
{}
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
().
__new__
(
cls
)
return
cls
.
_instance
def
get_model
(
self
,
model_name
:
str
):
if
model_name
not
in
self
.
_models
:
self
.
_models
[
model_name
]
=
model_init
(
model_name
=
model_name
)
return
self
.
_models
[
model_name
]
def
do_predict
(
boxes
:
List
[
List
[
int
]],
model
)
->
List
[
int
]:
from
magic_pdf.model.v3.helpers
import
(
boxes2inputs
,
parse_logits
,
prepare_inputs
)
inputs
=
boxes2inputs
(
boxes
)
inputs
=
prepare_inputs
(
inputs
,
model
)
logits
=
model
(
**
inputs
).
logits
.
cpu
().
squeeze
(
0
)
return
parse_logits
(
logits
,
len
(
boxes
))
def
cal_block_index
(
fix_blocks
,
sorted_bboxes
):
for
block
in
fix_blocks
:
line_index_list
=
[]
if
len
(
block
[
'lines'
])
==
0
:
block
[
'index'
]
=
sorted_bboxes
.
index
(
block
[
'bbox'
])
else
:
for
line
in
block
[
'lines'
]:
line
[
'index'
]
=
sorted_bboxes
.
index
(
line
[
'bbox'
])
line_index_list
.
append
(
line
[
'index'
])
median_value
=
statistics
.
median
(
line_index_list
)
block
[
'index'
]
=
median_value
# 删除图表body block中的虚拟line信息, 并用real_lines信息回填
if
block
[
'type'
]
in
[
BlockType
.
ImageBody
,
BlockType
.
TableBody
]:
block
[
'virtual_lines'
]
=
copy
.
deepcopy
(
block
[
'lines'
])
block
[
'lines'
]
=
copy
.
deepcopy
(
block
[
'real_lines'
])
del
block
[
'real_lines'
]
return
fix_blocks
def
insert_lines_into_block
(
block_bbox
,
line_height
,
page_w
,
page_h
):
# block_bbox是一个元组(x0, y0, x1, y1),其中(x0, y0)是左下角坐标,(x1, y1)是右上角坐标
x0
,
y0
,
x1
,
y1
=
block_bbox
block_height
=
y1
-
y0
block_weight
=
x1
-
x0
# 如果block高度小于n行正文,则直接返回block的bbox
if
line_height
*
3
<
block_height
:
if
(
block_height
>
page_h
*
0.25
and
page_w
*
0.5
>
block_weight
>
page_w
*
0.25
):
# 可能是双列结构,可以切细点
lines
=
int
(
block_height
/
line_height
)
+
1
else
:
# 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
if
block_weight
>
page_w
*
0.4
:
line_height
=
(
y1
-
y0
)
/
3
lines
=
3
elif
block_weight
>
page_w
*
0.25
:
# (可能是三列结构,也切细点)
lines
=
int
(
block_height
/
line_height
)
+
1
else
:
# 判断长宽比
if
block_height
/
block_weight
>
1.2
:
# 细长的不分
return
[[
x0
,
y0
,
x1
,
y1
]]
else
:
# 不细长的还是分成两行
line_height
=
(
y1
-
y0
)
/
2
lines
=
2
# 确定从哪个y位置开始绘制线条
current_y
=
y0
# 用于存储线条的位置信息[(x0, y), ...]
lines_positions
=
[]
for
i
in
range
(
lines
):
lines_positions
.
append
([
x0
,
current_y
,
x1
,
current_y
+
line_height
])
current_y
+=
line_height
return
lines_positions
else
:
return
[[
x0
,
y0
,
x1
,
y1
]]
def
sort_lines_by_model
(
fix_blocks
,
page_w
,
page_h
,
line_height
):
page_line_list
=
[]
for
block
in
fix_blocks
:
if
block
[
'type'
]
in
[
BlockType
.
Text
,
BlockType
.
Title
,
BlockType
.
InterlineEquation
,
BlockType
.
ImageCaption
,
BlockType
.
ImageFootnote
,
BlockType
.
TableCaption
,
BlockType
.
TableFootnote
]:
if
len
(
block
[
'lines'
])
==
0
:
bbox
=
block
[
'bbox'
]
lines
=
insert_lines_into_block
(
bbox
,
line_height
,
page_w
,
page_h
)
for
line
in
lines
:
block
[
'lines'
].
append
({
'bbox'
:
line
,
'spans'
:
[]})
page_line_list
.
extend
(
lines
)
else
:
for
line
in
block
[
'lines'
]:
bbox
=
line
[
'bbox'
]
page_line_list
.
append
(
bbox
)
elif
block
[
'type'
]
in
[
BlockType
.
ImageBody
,
BlockType
.
TableBody
]:
bbox
=
block
[
'bbox'
]
block
[
"real_lines"
]
=
copy
.
deepcopy
(
block
[
'lines'
])
lines
=
insert_lines_into_block
(
bbox
,
line_height
,
page_w
,
page_h
)
block
[
'lines'
]
=
[]
for
line
in
lines
:
block
[
'lines'
].
append
({
'bbox'
:
line
,
'spans'
:
[]})
page_line_list
.
extend
(
lines
)
# 使用layoutreader排序
x_scale
=
1000.0
/
page_w
y_scale
=
1000.0
/
page_h
boxes
=
[]
# logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
for
left
,
top
,
right
,
bottom
in
page_line_list
:
if
left
<
0
:
logger
.
warning
(
f
'left < 0, left:
{
left
}
, right:
{
right
}
, top:
{
top
}
, bottom:
{
bottom
}
, page_w:
{
page_w
}
, page_h:
{
page_h
}
'
)
# noqa: E501
left
=
0
if
right
>
page_w
:
logger
.
warning
(
f
'right > page_w, left:
{
left
}
, right:
{
right
}
, top:
{
top
}
, bottom:
{
bottom
}
, page_w:
{
page_w
}
, page_h:
{
page_h
}
'
)
# noqa: E501
right
=
page_w
if
top
<
0
:
logger
.
warning
(
f
'top < 0, left:
{
left
}
, right:
{
right
}
, top:
{
top
}
, bottom:
{
bottom
}
, page_w:
{
page_w
}
, page_h:
{
page_h
}
'
)
# noqa: E501
top
=
0
if
bottom
>
page_h
:
logger
.
warning
(
f
'bottom > page_h, left:
{
left
}
, right:
{
right
}
, top:
{
top
}
, bottom:
{
bottom
}
, page_w:
{
page_w
}
, page_h:
{
page_h
}
'
)
# noqa: E501
bottom
=
page_h
left
=
round
(
left
*
x_scale
)
top
=
round
(
top
*
y_scale
)
right
=
round
(
right
*
x_scale
)
bottom
=
round
(
bottom
*
y_scale
)
assert
(
1000
>=
right
>=
left
>=
0
and
1000
>=
bottom
>=
top
>=
0
),
f
'Invalid box. right:
{
right
}
, left:
{
left
}
, bottom:
{
bottom
}
, top:
{
top
}
'
# noqa: E126, E121
boxes
.
append
([
left
,
top
,
right
,
bottom
])
model_manager
=
ModelSingleton
()
model
=
model_manager
.
get_model
(
'layoutreader'
)
with
torch
.
no_grad
():
orders
=
do_predict
(
boxes
,
model
)
sorted_bboxes
=
[
page_line_list
[
i
]
for
i
in
orders
]
return
sorted_bboxes
def
get_line_height
(
blocks
):
page_line_height_list
=
[]
for
block
in
blocks
:
if
block
[
'type'
]
in
[
BlockType
.
Text
,
BlockType
.
Title
,
BlockType
.
ImageCaption
,
BlockType
.
ImageFootnote
,
BlockType
.
TableCaption
,
BlockType
.
TableFootnote
]:
for
line
in
block
[
'lines'
]:
bbox
=
line
[
'bbox'
]
page_line_height_list
.
append
(
int
(
bbox
[
3
]
-
bbox
[
1
]))
if
len
(
page_line_height_list
)
>
0
:
return
statistics
.
median
(
page_line_height_list
)
else
:
return
10
def
process_groups
(
groups
,
body_key
,
caption_key
,
footnote_key
):
body_blocks
=
[]
caption_blocks
=
[]
footnote_blocks
=
[]
for
i
,
group
in
enumerate
(
groups
):
group
[
body_key
][
'group_id'
]
=
i
body_blocks
.
append
(
group
[
body_key
])
for
caption_block
in
group
[
caption_key
]:
caption_block
[
'group_id'
]
=
i
caption_blocks
.
append
(
caption_block
)
for
footnote_block
in
group
[
footnote_key
]:
footnote_block
[
'group_id'
]
=
i
footnote_blocks
.
append
(
footnote_block
)
return
body_blocks
,
caption_blocks
,
footnote_blocks
def
process_block_list
(
blocks
,
body_type
,
block_type
):
indices
=
[
block
[
'index'
]
for
block
in
blocks
]
median_index
=
statistics
.
median
(
indices
)
body_bbox
=
next
((
block
[
'bbox'
]
for
block
in
blocks
if
block
.
get
(
'type'
)
==
body_type
),
[])
return
{
'type'
:
block_type
,
'bbox'
:
body_bbox
,
'blocks'
:
blocks
,
'index'
:
median_index
,
}
def
revert_group_blocks
(
blocks
):
image_groups
=
{}
table_groups
=
{}
new_blocks
=
[]
for
block
in
blocks
:
if
block
[
'type'
]
in
[
BlockType
.
ImageBody
,
BlockType
.
ImageCaption
,
BlockType
.
ImageFootnote
]:
group_id
=
block
[
'group_id'
]
if
group_id
not
in
image_groups
:
image_groups
[
group_id
]
=
[]
image_groups
[
group_id
].
append
(
block
)
elif
block
[
'type'
]
in
[
BlockType
.
TableBody
,
BlockType
.
TableCaption
,
BlockType
.
TableFootnote
]:
group_id
=
block
[
'group_id'
]
if
group_id
not
in
table_groups
:
table_groups
[
group_id
]
=
[]
table_groups
[
group_id
].
append
(
block
)
else
:
new_blocks
.
append
(
block
)
for
group_id
,
blocks
in
image_groups
.
items
():
new_blocks
.
append
(
process_block_list
(
blocks
,
BlockType
.
ImageBody
,
BlockType
.
Image
))
for
group_id
,
blocks
in
table_groups
.
items
():
new_blocks
.
append
(
process_block_list
(
blocks
,
BlockType
.
TableBody
,
BlockType
.
Table
))
return
new_blocks
def
remove_outside_spans
(
spans
,
all_bboxes
,
all_discarded_blocks
):
def
get_block_bboxes
(
blocks
,
block_type_list
):
return
[
block
[
0
:
4
]
for
block
in
blocks
if
block
[
7
]
in
block_type_list
]
image_bboxes
=
get_block_bboxes
(
all_bboxes
,
[
BlockType
.
ImageBody
])
table_bboxes
=
get_block_bboxes
(
all_bboxes
,
[
BlockType
.
TableBody
])
other_block_type
=
[]
for
block_type
in
BlockType
.
__dict__
.
values
():
if
not
isinstance
(
block_type
,
str
):
continue
if
block_type
not
in
[
BlockType
.
ImageBody
,
BlockType
.
TableBody
]:
other_block_type
.
append
(
block_type
)
other_block_bboxes
=
get_block_bboxes
(
all_bboxes
,
other_block_type
)
discarded_block_bboxes
=
get_block_bboxes
(
all_discarded_blocks
,
[
BlockType
.
Discarded
])
new_spans
=
[]
for
span
in
spans
:
span_bbox
=
span
[
'bbox'
]
span_type
=
span
[
'type'
]
if
any
(
calculate_overlap_area_in_bbox1_area_ratio
(
span_bbox
,
block_bbox
)
>
0.4
for
block_bbox
in
discarded_block_bboxes
):
new_spans
.
append
(
span
)
continue
if
span_type
==
ContentType
.
Image
:
if
any
(
calculate_overlap_area_in_bbox1_area_ratio
(
span_bbox
,
block_bbox
)
>
0.5
for
block_bbox
in
image_bboxes
):
new_spans
.
append
(
span
)
elif
span_type
==
ContentType
.
Table
:
if
any
(
calculate_overlap_area_in_bbox1_area_ratio
(
span_bbox
,
block_bbox
)
>
0.5
for
block_bbox
in
table_bboxes
):
new_spans
.
append
(
span
)
else
:
if
any
(
calculate_overlap_area_in_bbox1_area_ratio
(
span_bbox
,
block_bbox
)
>
0.5
for
block_bbox
in
other_block_bboxes
):
new_spans
.
append
(
span
)
return
new_spans
def
parse_page_core
(
page_doc
:
PageableData
,
magic_model
,
page_id
,
pdf_bytes_md5
,
imageWriter
,
parse_mode
):
need_drop
=
False
drop_reason
=
[]
"""从magic_model对象中获取后面会用到的区块信息"""
# img_blocks = magic_model.get_imgs(page_id)
# table_blocks = magic_model.get_tables(page_id)
img_groups
=
magic_model
.
get_imgs_v2
(
page_id
)
table_groups
=
magic_model
.
get_tables_v2
(
page_id
)
img_body_blocks
,
img_caption_blocks
,
img_footnote_blocks
=
process_groups
(
img_groups
,
'image_body'
,
'image_caption_list'
,
'image_footnote_list'
)
table_body_blocks
,
table_caption_blocks
,
table_footnote_blocks
=
process_groups
(
table_groups
,
'table_body'
,
'table_caption_list'
,
'table_footnote_list'
)
discarded_blocks
=
magic_model
.
get_discarded
(
page_id
)
text_blocks
=
magic_model
.
get_text_blocks
(
page_id
)
title_blocks
=
magic_model
.
get_title_blocks
(
page_id
)
inline_equations
,
interline_equations
,
interline_equation_blocks
=
(
magic_model
.
get_equations
(
page_id
)
)
page_w
,
page_h
=
magic_model
.
get_page_size
(
page_id
)
"""将所有区块的bbox整理到一起"""
# interline_equation_blocks参数不够准,后面切换到interline_equations上
interline_equation_blocks
=
[]
if
len
(
interline_equation_blocks
)
>
0
:
all_bboxes
,
all_discarded_blocks
=
ocr_prepare_bboxes_for_layout_split_v2
(
img_body_blocks
,
img_caption_blocks
,
img_footnote_blocks
,
table_body_blocks
,
table_caption_blocks
,
table_footnote_blocks
,
discarded_blocks
,
text_blocks
,
title_blocks
,
interline_equation_blocks
,
page_w
,
page_h
,
)
else
:
all_bboxes
,
all_discarded_blocks
=
ocr_prepare_bboxes_for_layout_split_v2
(
img_body_blocks
,
img_caption_blocks
,
img_footnote_blocks
,
table_body_blocks
,
table_caption_blocks
,
table_footnote_blocks
,
discarded_blocks
,
text_blocks
,
title_blocks
,
interline_equations
,
page_w
,
page_h
,
)
spans
=
magic_model
.
get_all_spans
(
page_id
)
"""根据parse_mode,构造spans"""
if
parse_mode
==
SupportedPdfParseMethod
.
TXT
:
"""ocr 中文本类的 span 用 pymu spans 替换!"""
pymu_spans
=
txt_spans_extract
(
page_doc
,
inline_equations
,
interline_equations
)
spans
=
replace_text_span
(
pymu_spans
,
spans
)
elif
parse_mode
==
SupportedPdfParseMethod
.
OCR
:
pass
else
:
raise
Exception
(
'parse_mode must be txt or ocr'
)
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
"""顺便删除大水印并保留abandon的span"""
spans
=
remove_outside_spans
(
spans
,
all_bboxes
,
all_discarded_blocks
)
"""删除重叠spans中置信度较低的那些"""
spans
,
dropped_spans_by_confidence
=
remove_overlaps_low_confidence_spans
(
spans
)
"""删除重叠spans中较小的那些"""
spans
,
dropped_spans_by_span_overlap
=
remove_overlaps_min_spans
(
spans
)
"""对image和table截图"""
spans
=
ocr_cut_image_and_table
(
spans
,
page_doc
,
page_id
,
pdf_bytes_md5
,
imageWriter
)
"""先处理不需要排版的discarded_blocks"""
discarded_block_with_spans
,
spans
=
fill_spans_in_blocks
(
all_discarded_blocks
,
spans
,
0.4
)
fix_discarded_blocks
=
fix_discarded_block
(
discarded_block_with_spans
)
"""如果当前页面没有bbox则跳过"""
if
len
(
all_bboxes
)
==
0
:
logger
.
warning
(
f
'skip this page, not found useful bbox, page_id:
{
page_id
}
'
)
return
ocr_construct_page_component_v2
(
[],
[],
page_id
,
page_w
,
page_h
,
[],
[],
[],
interline_equations
,
fix_discarded_blocks
,
need_drop
,
drop_reason
,
)
"""将span填入blocks中"""
block_with_spans
,
spans
=
fill_spans_in_blocks
(
all_bboxes
,
spans
,
0.5
)
"""对block进行fix操作"""
fix_blocks
=
fix_block_spans_v2
(
block_with_spans
)
"""获取所有line并计算正文line的高度"""
line_height
=
get_line_height
(
fix_blocks
)
"""获取所有line并对line排序"""
sorted_bboxes
=
sort_lines_by_model
(
fix_blocks
,
page_w
,
page_h
,
line_height
)
"""根据line的中位数算block的序列关系"""
fix_blocks
=
cal_block_index
(
fix_blocks
,
sorted_bboxes
)
"""将image和table的block还原回group形式参与后续流程"""
fix_blocks
=
revert_group_blocks
(
fix_blocks
)
"""重排block"""
sorted_blocks
=
sorted
(
fix_blocks
,
key
=
lambda
b
:
b
[
'index'
])
"""获取QA需要外置的list"""
images
,
tables
,
interline_equations
=
get_qa_need_list_v2
(
sorted_blocks
)
"""构造pdf_info_dict"""
page_info
=
ocr_construct_page_component_v2
(
sorted_blocks
,
[],
page_id
,
page_w
,
page_h
,
[],
images
,
tables
,
interline_equations
,
fix_discarded_blocks
,
need_drop
,
drop_reason
,
)
return
page_info
def
pdf_parse_union
(
dataset
:
Dataset
,
model_list
,
imageWriter
,
parse_mode
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
):
pdf_bytes_md5
=
compute_md5
(
dataset
.
data_bits
())
"""初始化空的pdf_info_dict"""
pdf_info_dict
=
{}
"""用model_list和docs对象初始化magic_model"""
magic_model
=
MagicModel
(
model_list
,
dataset
)
"""根据输入的起始范围解析pdf"""
# end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
end_page_id
=
(
end_page_id
if
end_page_id
is
not
None
and
end_page_id
>=
0
else
len
(
dataset
)
-
1
)
if
end_page_id
>
len
(
dataset
)
-
1
:
logger
.
warning
(
'end_page_id is out of range, use pdf_docs length'
)
end_page_id
=
len
(
dataset
)
-
1
"""初始化启动时间"""
start_time
=
time
.
time
()
for
page_id
,
page
in
enumerate
(
dataset
):
"""debug时输出每页解析的耗时."""
if
debug_mode
:
time_now
=
time
.
time
()
logger
.
info
(
f
'page_id:
{
page_id
}
, last_page_cost_time:
{
get_delta_time
(
start_time
)
}
'
)
start_time
=
time_now
"""解析pdf中的每一页"""
if
start_page_id
<=
page_id
<=
end_page_id
:
page_info
=
parse_page_core
(
page
,
magic_model
,
page_id
,
pdf_bytes_md5
,
imageWriter
,
parse_mode
)
else
:
page_info
=
page
.
get_page_info
()
page_w
=
page_info
.
w
page_h
=
page_info
.
h
page_info
=
ocr_construct_page_component_v2
(
[],
[],
page_id
,
page_w
,
page_h
,
[],
[],
[],
[],
[],
True
,
'skip page'
)
pdf_info_dict
[
f
'page_
{
page_id
}
'
]
=
page_info
"""分段"""
para_split
(
pdf_info_dict
,
debug_mode
=
debug_mode
)
"""dict转list"""
pdf_info_list
=
dict_to_list
(
pdf_info_dict
)
new_pdf_info_dict
=
{
'pdf_info'
:
pdf_info_list
,
}
clean_memory
()
return
new_pdf_info_dict
if
__name__
==
'__main__'
:
pass
magic_pdf/pipe/AbsPipe.py
View file @
3a42ebbf
...
...
@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT
=
"txt"
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_bytes
=
pdf_bytes
self
.
model_list
=
model_list
self
.
image_writer
=
image_writer
...
...
@@ -25,6 +25,10 @@ class AbsPipe(ABC):
self
.
is_debug
=
is_debug
self
.
start_page_id
=
start_page_id
self
.
end_page_id
=
end_page_id
self
.
lang
=
lang
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
get_compress_pdf_mid_data
(
self
):
return
JsonCompressor
.
compress_json
(
self
.
pdf_mid_data
)
...
...
magic_pdf/pipe/OCRPipe.py
View file @
3a42ebbf
...
...
@@ -10,19 +10,25 @@ from magic_pdf.user_api import parse_ocr_pdf
class
OCRPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
pass
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/TXTPipe.py
View file @
3a42ebbf
...
...
@@ -11,19 +11,25 @@ from magic_pdf.user_api import parse_txt_pdf
class
TXTPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
pass
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_txt_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/UNIPipe.py
View file @
3a42ebbf
...
...
@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class
UNIPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
jso_useful_key
:
dict
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_type
=
jso_useful_key
[
"_pdf_type"
]
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
)
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
if
len
(
self
.
model_list
)
==
0
:
self
.
input_model_is_empty
=
True
else
:
...
...
@@ -28,22 +30,29 @@ class UNIPipe(AbsPipe):
def
pipe_analyze
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
pdf_mid_data
=
parse_union_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
input_model_is_empty
=
self
.
input_model_is_empty
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
)
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
NONE_WITH_REASON
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
logger
.
info
(
"uni_pipe mk content list finished"
)
return
result
...
...
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
View file @
3a42ebbf
from
loguru
import
logger
from
magic_pdf.libs.boxbase
import
get_minbox_if_overlap_by_ratio
,
calculate_overlap_area_in_bbox1_area_ratio
,
\
calculate_iou
calculate_iou
,
calculate_vertical_projection_overlap_ratio
from
magic_pdf.libs.drop_tag
import
DropTag
from
magic_pdf.libs.ocr_content_type
import
BlockType
from
magic_pdf.pre_proc.remove_bbox_overlap
import
remove_overlap_between_bbox_for_block
...
...
@@ -60,6 +60,88 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
return
all_bboxes
,
all_discarded_blocks
,
drop_reasons
def
add_bboxes
(
blocks
,
block_type
,
bboxes
):
for
block
in
blocks
:
x0
,
y0
,
x1
,
y1
=
block
[
'bbox'
]
if
block_type
in
[
BlockType
.
ImageBody
,
BlockType
.
ImageCaption
,
BlockType
.
ImageFootnote
,
BlockType
.
TableBody
,
BlockType
.
TableCaption
,
BlockType
.
TableFootnote
]:
bboxes
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
block_type
,
None
,
None
,
None
,
None
,
block
[
"score"
],
block
[
"group_id"
]])
else
:
bboxes
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
block_type
,
None
,
None
,
None
,
None
,
block
[
"score"
]])
def
ocr_prepare_bboxes_for_layout_split_v2
(
img_body_blocks
,
img_caption_blocks
,
img_footnote_blocks
,
table_body_blocks
,
table_caption_blocks
,
table_footnote_blocks
,
discarded_blocks
,
text_blocks
,
title_blocks
,
interline_equation_blocks
,
page_w
,
page_h
):
all_bboxes
=
[]
add_bboxes
(
img_body_blocks
,
BlockType
.
ImageBody
,
all_bboxes
)
add_bboxes
(
img_caption_blocks
,
BlockType
.
ImageCaption
,
all_bboxes
)
add_bboxes
(
img_footnote_blocks
,
BlockType
.
ImageFootnote
,
all_bboxes
)
add_bboxes
(
table_body_blocks
,
BlockType
.
TableBody
,
all_bboxes
)
add_bboxes
(
table_caption_blocks
,
BlockType
.
TableCaption
,
all_bboxes
)
add_bboxes
(
table_footnote_blocks
,
BlockType
.
TableFootnote
,
all_bboxes
)
add_bboxes
(
text_blocks
,
BlockType
.
Text
,
all_bboxes
)
add_bboxes
(
title_blocks
,
BlockType
.
Title
,
all_bboxes
)
add_bboxes
(
interline_equation_blocks
,
BlockType
.
InterlineEquation
,
all_bboxes
)
'''block嵌套问题解决'''
'''文本框与标题框重叠,优先信任文本框'''
all_bboxes
=
fix_text_overlap_title_blocks
(
all_bboxes
)
'''任何框体与舍弃框重叠,优先信任舍弃框'''
all_bboxes
=
remove_need_drop_blocks
(
all_bboxes
,
discarded_blocks
)
# interline_equation 与title或text框冲突的情况,分两种情况处理
'''interline_equation框与文本类型框iou比较接近1的时候,信任行间公式框'''
all_bboxes
=
fix_interline_equation_overlap_text_blocks_with_hi_iou
(
all_bboxes
)
'''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
# 通过后续大框套小框逻辑删除
'''discarded_blocks'''
all_discarded_blocks
=
[]
add_bboxes
(
discarded_blocks
,
BlockType
.
Discarded
,
all_discarded_blocks
)
'''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
footnote_blocks
=
[]
for
discarded
in
discarded_blocks
:
x0
,
y0
,
x1
,
y1
=
discarded
[
'bbox'
]
if
(
x1
-
x0
)
>
(
page_w
/
3
)
and
(
y1
-
y0
)
>
10
and
y0
>
(
page_h
/
2
):
footnote_blocks
.
append
([
x0
,
y0
,
x1
,
y1
])
'''移除在footnote下面的任何框'''
need_remove_blocks
=
find_blocks_under_footnote
(
all_bboxes
,
footnote_blocks
)
if
len
(
need_remove_blocks
)
>
0
:
for
block
in
need_remove_blocks
:
all_bboxes
.
remove
(
block
)
all_discarded_blocks
.
append
(
block
)
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes
=
remove_overlaps_min_blocks
(
all_bboxes
)
all_discarded_blocks
=
remove_overlaps_min_blocks
(
all_discarded_blocks
)
'''将剩余的bbox做分离处理,防止后面分layout时出错'''
all_bboxes
,
drop_reasons
=
remove_overlap_between_bbox_for_block
(
all_bboxes
)
return
all_bboxes
,
all_discarded_blocks
def
find_blocks_under_footnote
(
all_bboxes
,
footnote_blocks
):
need_remove_blocks
=
[]
for
block
in
all_bboxes
:
block_x0
,
block_y0
,
block_x1
,
block_y1
=
block
[:
4
]
for
footnote_bbox
in
footnote_blocks
:
footnote_x0
,
footnote_y0
,
footnote_x1
,
footnote_y1
=
footnote_bbox
# 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
if
block_y0
>=
footnote_y1
and
calculate_vertical_projection_overlap_ratio
((
block_x0
,
block_y0
,
block_x1
,
block_y1
),
footnote_bbox
)
>=
0.8
:
if
block
not
in
need_remove_blocks
:
need_remove_blocks
.
append
(
block
)
break
return
need_remove_blocks
def
fix_interline_equation_overlap_text_blocks_with_hi_iou
(
all_bboxes
):
# 先提取所有text和interline block
text_blocks
=
[]
...
...
magic_pdf/pre_proc/ocr_dict_merge.py
View file @
3a42ebbf
...
...
@@ -49,8 +49,7 @@ def merge_spans_to_line(spans):
continue
# 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
if
__is_overlaps_y_exceeds_threshold
(
span
[
'bbox'
],
current_line
[
-
1
][
'bbox'
]):
if
__is_overlaps_y_exceeds_threshold
(
span
[
'bbox'
],
current_line
[
-
1
][
'bbox'
],
0.5
):
current_line
.
append
(
span
)
else
:
# 否则,开始新行
...
...
@@ -154,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
'type'
:
block_type
,
'bbox'
:
block_bbox
,
}
if
block_type
in
[
BlockType
.
ImageBody
,
BlockType
.
ImageCaption
,
BlockType
.
ImageFootnote
,
BlockType
.
TableBody
,
BlockType
.
TableCaption
,
BlockType
.
TableFootnote
]:
block_dict
[
"group_id"
]
=
block
[
-
1
]
block_spans
=
[]
for
span
in
spans
:
span_bbox
=
span
[
'bbox'
]
...
...
@@ -202,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
return
fix_blocks
def
fix_block_spans_v2
(
block_with_spans
):
"""1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
需要将caption和footnote的text_span放入相应img_block和table_block内的
caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
fix_blocks
=
[]
for
block
in
block_with_spans
:
block_type
=
block
[
'type'
]
if
block_type
in
[
BlockType
.
Text
,
BlockType
.
Title
,
BlockType
.
ImageCaption
,
BlockType
.
ImageFootnote
,
BlockType
.
TableCaption
,
BlockType
.
TableFootnote
]:
block
=
fix_text_block
(
block
)
elif
block_type
in
[
BlockType
.
InterlineEquation
,
BlockType
.
ImageBody
,
BlockType
.
TableBody
]:
block
=
fix_interline_block
(
block
)
else
:
continue
fix_blocks
.
append
(
block
)
return
fix_blocks
def
fix_discarded_block
(
discarded_block_with_spans
):
fix_discarded_blocks
=
[]
for
block
in
discarded_block_with_spans
:
...
...
magic_pdf/resources/model_config/UniMERNet/demo.yaml
View file @
3a42ebbf
...
...
@@ -2,13 +2,13 @@ model:
arch
:
unimernet
model_type
:
unimernet
model_config
:
model_name
:
./models
max_seq_len
:
1
024
length_aware
:
False
model_name
:
./models
/unimernet_base
max_seq_len
:
1
536
load_pretrained
:
True
pretrained
:
./models/pytorch_model.
bin
pretrained
:
'
./models/
unimernet_base/
pytorch_model.
pth'
tokenizer_config
:
path
:
./models
path
:
./models
/unimernet_base
datasets
:
formula_rec_eval
:
...
...
magic_pdf/resources/model_config/model_configs.yaml
View file @
3a42ebbf
config
:
device
:
cpu
layout
:
True
formula
:
True
table_config
:
model
:
TableMaster
is_table_recog_enable
:
False
max_time
:
400
weights
:
layout
:
Layout/model_final.pth
mfd
:
MFD/weights.pt
mfr
:
MFR/UniMERNet
layoutlmv3
:
Layout/LayoutLMv3/model_final.pth
doclayout_yolo
:
Layout/YOLO/doclayout_yolo_ft.pt
yolo_v8_mfd
:
MFD/YOLO/yolo_v8_ft.pt
unimernet_small
:
MFR/unimernet_small
struct_eqtable
:
TabRec/StructEqTable
TableMaster
:
TabRec/TableMaster
\ No newline at end of file
tablemaster
:
TabRec/TableMaster
\ No newline at end of file
magic_pdf/tools/cli.py
View file @
3a42ebbf
...
...
@@ -44,6 +44,18 @@ auto: automatically choose the best method for parsing pdf from ocr and txt.
without method specified, auto will be used by default."""
,
default
=
'auto'
,
)
@
click
.
option
(
'-l'
,
'--lang'
,
'lang'
,
type
=
str
,
help
=
"""
Input the languages in the pdf (if known) to improve OCR accuracy. Optional.
You should input "Abbreviation" with language form url:
https://paddlepaddle.github.io/PaddleOCR/latest/en/ppocr/blog/multi_languages.html#5-support-languages-and-abbreviations
"""
,
default
=
None
,
)
@
click
.
option
(
'-d'
,
'--debug'
,
...
...
@@ -68,7 +80,7 @@ without method specified, auto will be used by default.""",
help
=
'The ending page for PDF parsing, beginning from 0.'
,
default
=
None
,
)
def
cli
(
path
,
output_dir
,
method
,
debug_able
,
start_page_id
,
end_page_id
):
def
cli
(
path
,
output_dir
,
method
,
lang
,
debug_able
,
start_page_id
,
end_page_id
):
model_config
.
__use_inside_model__
=
True
model_config
.
__model_mode__
=
'full'
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
...
...
@@ -90,6 +102,7 @@ def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
debug_able
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
except
Exception
as
e
:
...
...
magic_pdf/tools/common.py
View file @
3a42ebbf
...
...
@@ -6,8 +6,8 @@ import click
from
loguru
import
logger
import
magic_pdf.model
as
model_config
from
magic_pdf.libs.draw_bbox
import
(
draw_layout_bbox
,
draw_
span
_bbox
,
dr
o
w_model_bbox
)
from
magic_pdf.libs.draw_bbox
import
(
draw_layout_bbox
,
draw_
line_sort
_bbox
,
dr
a
w_model_bbox
,
draw_span_bbox
)
from
magic_pdf.libs.MakeContentConfig
import
DropMode
,
MakeMode
from
magic_pdf.pipe.OCRPipe
import
OCRPipe
from
magic_pdf.pipe.TXTPipe
import
TXTPipe
...
...
@@ -39,16 +39,21 @@ def do_parse(
f_dump_middle_json
=
True
,
f_dump_model_json
=
True
,
f_dump_orig_pdf
=
True
,
f_dump_content_list
=
Fals
e
,
f_dump_content_list
=
Tru
e
,
f_make_md_mode
=
MakeMode
.
MM_MD
,
f_draw_model_bbox
=
False
,
f_draw_line_sort_bbox
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
):
if
debug_able
:
logger
.
warning
(
'debug mode is on'
)
f_dump_content_list
=
True
f_draw_model_bbox
=
True
f_draw_line_sort_bbox
=
True
orig_model_list
=
copy
.
deepcopy
(
model_list
)
local_image_dir
,
local_md_dir
=
prepare_env
(
output_dir
,
pdf_file_name
,
...
...
@@ -61,13 +66,16 @@ def do_parse(
if
parse_method
==
'auto'
:
jso_useful_key
=
{
'_pdf_type'
:
''
,
'model_list'
:
model_list
}
pipe
=
UNIPipe
(
pdf_bytes
,
jso_useful_key
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'txt'
:
pipe
=
TXTPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'ocr'
:
pipe
=
OCRPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
else
:
logger
.
error
(
'unknown parse method'
)
exit
(
1
)
...
...
@@ -89,7 +97,9 @@ def do_parse(
if
f_draw_span_bbox
:
draw_span_bbox
(
pdf_info
,
pdf_bytes
,
local_md_dir
,
pdf_file_name
)
if
f_draw_model_bbox
:
drow_model_bbox
(
copy
.
deepcopy
(
orig_model_list
),
pdf_bytes
,
local_md_dir
,
pdf_file_name
)
draw_model_bbox
(
copy
.
deepcopy
(
orig_model_list
),
pdf_bytes
,
local_md_dir
,
pdf_file_name
)
if
f_draw_line_sort_bbox
:
draw_line_sort_bbox
(
pdf_info
,
pdf_bytes
,
local_md_dir
,
pdf_file_name
)
md_content
=
pipe
.
pipe_mk_markdown
(
image_dir
,
drop_mode
=
DropMode
.
NONE
,
...
...
magic_pdf/user_api.py
View file @
3a42ebbf
...
...
@@ -26,7 +26,7 @@ PARSE_TYPE_OCR = "ocr"
def
parse_txt_pdf
(
pdf_bytes
:
bytes
,
pdf_models
:
list
,
imageWriter
:
AbsReaderWriter
,
is_debug
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
*
args
,
**
kwargs
):
"""
解析文本类pdf
...
...
@@ -44,11 +44,14 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
pdf_info_dict
[
"_version_name"
]
=
__version__
if
lang
is
not
None
:
pdf_info_dict
[
"_lang"
]
=
lang
return
pdf_info_dict
def
parse_ocr_pdf
(
pdf_bytes
:
bytes
,
pdf_models
:
list
,
imageWriter
:
AbsReaderWriter
,
is_debug
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
*
args
,
**
kwargs
):
"""
解析ocr类pdf
...
...
@@ -66,12 +69,15 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
pdf_info_dict
[
"_version_name"
]
=
__version__
if
lang
is
not
None
:
pdf_info_dict
[
"_lang"
]
=
lang
return
pdf_info_dict
def
parse_union_pdf
(
pdf_bytes
:
bytes
,
pdf_models
:
list
,
imageWriter
:
AbsReaderWriter
,
is_debug
=
False
,
input_model_is_empty
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
*
args
,
**
kwargs
):
"""
ocr和文本混合的pdf,全部解析出来
...
...
@@ -95,9 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if
pdf_info_dict
is
None
or
pdf_info_dict
.
get
(
"_need_drop"
,
False
):
logger
.
warning
(
f
"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr"
)
if
input_model_is_empty
:
pdf_models
=
doc_analyze
(
pdf_bytes
,
ocr
=
True
,
layout_model
=
kwargs
.
get
(
"layout_model"
,
None
)
formula_enable
=
kwargs
.
get
(
"formula_enable"
,
None
)
table_enable
=
kwargs
.
get
(
"table_enable"
,
None
)
pdf_models
=
doc_analyze
(
pdf_bytes
,
ocr
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
)
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
,
)
pdf_info_dict
=
parse_pdf
(
parse_pdf_by_ocr
)
if
pdf_info_dict
is
None
:
raise
Exception
(
"Both parse_pdf_by_txt and parse_pdf_by_ocr failed."
)
...
...
@@ -108,4 +124,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
pdf_info_dict
[
"_version_name"
]
=
__version__
if
lang
is
not
None
:
pdf_info_dict
[
"_lang"
]
=
lang
return
pdf_info_dict
Prev
1
2
3
4
5
6
7
8
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment