Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
d13f3c6d
Commit
d13f3c6d
authored
Jan 06, 2025
by
icecraft
Browse files
refactor: remove unused method in MagicModel class
parent
ad9abc32
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
5 additions
and
1231 deletions
+5
-1231
magic_pdf/model/magic_model.py
magic_pdf/model/magic_model.py
+4
-435
magic_pdf/pdf_parse_by_ocr.py
magic_pdf/pdf_parse_by_ocr.py
+0
-22
magic_pdf/pdf_parse_by_txt.py
magic_pdf/pdf_parse_by_txt.py
+0
-23
magic_pdf/pipe/AbsPipe.py
magic_pdf/pipe/AbsPipe.py
+0
-99
magic_pdf/pipe/OCRPipe.py
magic_pdf/pipe/OCRPipe.py
+0
-80
magic_pdf/pipe/TXTPipe.py
magic_pdf/pipe/TXTPipe.py
+0
-42
magic_pdf/pipe/UNIPipe.py
magic_pdf/pipe/UNIPipe.py
+0
-150
magic_pdf/pipe/__init__.py
magic_pdf/pipe/__init__.py
+0
-0
magic_pdf/rw/AbsReaderWriter.py
magic_pdf/rw/AbsReaderWriter.py
+0
-17
magic_pdf/rw/DiskReaderWriter.py
magic_pdf/rw/DiskReaderWriter.py
+0
-74
magic_pdf/rw/S3ReaderWriter.py
magic_pdf/rw/S3ReaderWriter.py
+0
-142
magic_pdf/rw/__init__.py
magic_pdf/rw/__init__.py
+0
-0
magic_pdf/user_api.py
magic_pdf/user_api.py
+0
-144
tests/test_cli/test_bench_gpu.py
tests/test_cli/test_bench_gpu.py
+1
-3
No files found.
magic_pdf/model/magic_model.py
View file @
d13f3c6d
...
@@ -3,12 +3,9 @@ import enum
...
@@ -3,12 +3,9 @@ import enum
from
magic_pdf.config.model_block_type
import
ModelBlockTypeEnum
from
magic_pdf.config.model_block_type
import
ModelBlockTypeEnum
from
magic_pdf.config.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.config.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.boxbase
import
(
_is_in
,
_is_part_overlap
,
bbox_distance
,
from
magic_pdf.libs.boxbase
import
(
_is_in
,
bbox_distance
,
bbox_relative_pos
,
bbox_relative_pos
,
box_area
,
calculate_iou
,
calculate_iou
)
calculate_overlap_area_in_bbox1_area_ratio
,
get_overlap_area
)
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.libs.local_math
import
float_gt
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
CAPATION_OVERLAP_AREA_RATIO
=
0.6
CAPATION_OVERLAP_AREA_RATIO
=
0.6
...
@@ -208,393 +205,6 @@ class MagicModel:
...
@@ -208,393 +205,6 @@ class MagicModel:
keep
[
i
]
=
False
keep
[
i
]
=
False
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
return
[
bboxes
[
i
]
for
i
in
range
(
N
)
if
keep
[
i
]]
def
__tie_up_category_by_distance
(
self
,
page_no
,
subject_category_id
,
object_category_id
):
"""假定每个 subject 最多有一个 object (可以有多个相邻的 object 合并为单个 object),每个 object
只能属于一个 subject."""
ret
=
[]
MAX_DIS_OF_POINT
=
10
**
9
+
7
"""
subject 和 object 的 bbox 会合并成一个大的 bbox (named: merged bbox)。
筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。
再求出筛选出的 subjects 和 object 的最短距离
"""
def
search_overlap_between_boxes
(
subject_idx
,
object_idx
):
idxes
=
[
subject_idx
,
object_idx
]
x0s
=
[
all_bboxes
[
idx
][
'bbox'
][
0
]
for
idx
in
idxes
]
y0s
=
[
all_bboxes
[
idx
][
'bbox'
][
1
]
for
idx
in
idxes
]
x1s
=
[
all_bboxes
[
idx
][
'bbox'
][
2
]
for
idx
in
idxes
]
y1s
=
[
all_bboxes
[
idx
][
'bbox'
][
3
]
for
idx
in
idxes
]
merged_bbox
=
[
min
(
x0s
),
min
(
y0s
),
max
(
x1s
),
max
(
y1s
),
]
ratio
=
0
other_objects
=
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
not
in
(
object_category_id
,
subject_category_id
),
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
for
other_object
in
other_objects
:
ratio
=
max
(
ratio
,
get_overlap_area
(
merged_bbox
,
other_object
[
'bbox'
])
*
1.0
/
box_area
(
all_bboxes
[
object_idx
][
'bbox'
]),
)
if
ratio
>=
MERGE_BOX_OVERLAP_AREA_RATIO
:
break
return
ratio
def
may_find_other_nearest_bbox
(
subject_idx
,
object_idx
):
ret
=
float
(
'inf'
)
x0
=
min
(
all_bboxes
[
subject_idx
][
'bbox'
][
0
],
all_bboxes
[
object_idx
][
'bbox'
][
0
]
)
y0
=
min
(
all_bboxes
[
subject_idx
][
'bbox'
][
1
],
all_bboxes
[
object_idx
][
'bbox'
][
1
]
)
x1
=
max
(
all_bboxes
[
subject_idx
][
'bbox'
][
2
],
all_bboxes
[
object_idx
][
'bbox'
][
2
]
)
y1
=
max
(
all_bboxes
[
subject_idx
][
'bbox'
][
3
],
all_bboxes
[
object_idx
][
'bbox'
][
3
]
)
object_area
=
abs
(
all_bboxes
[
object_idx
][
'bbox'
][
2
]
-
all_bboxes
[
object_idx
][
'bbox'
][
0
]
)
*
abs
(
all_bboxes
[
object_idx
][
'bbox'
][
3
]
-
all_bboxes
[
object_idx
][
'bbox'
][
1
]
)
for
i
in
range
(
len
(
all_bboxes
)):
if
(
i
==
subject_idx
or
all_bboxes
[
i
][
'category_id'
]
!=
subject_category_id
):
continue
if
_is_part_overlap
([
x0
,
y0
,
x1
,
y1
],
all_bboxes
[
i
][
'bbox'
])
or
_is_in
(
all_bboxes
[
i
][
'bbox'
],
[
x0
,
y0
,
x1
,
y1
]
):
i_area
=
abs
(
all_bboxes
[
i
][
'bbox'
][
2
]
-
all_bboxes
[
i
][
'bbox'
][
0
]
)
*
abs
(
all_bboxes
[
i
][
'bbox'
][
3
]
-
all_bboxes
[
i
][
'bbox'
][
1
])
if
i_area
>=
object_area
:
ret
=
min
(
float
(
'inf'
),
dis
[
i
][
object_idx
])
return
ret
def
expand_bbbox
(
idxes
):
x0s
=
[
all_bboxes
[
idx
][
'bbox'
][
0
]
for
idx
in
idxes
]
y0s
=
[
all_bboxes
[
idx
][
'bbox'
][
1
]
for
idx
in
idxes
]
x1s
=
[
all_bboxes
[
idx
][
'bbox'
][
2
]
for
idx
in
idxes
]
y1s
=
[
all_bboxes
[
idx
][
'bbox'
][
3
]
for
idx
in
idxes
]
return
min
(
x0s
),
min
(
y0s
),
max
(
x1s
),
max
(
y1s
)
subjects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
subject_category_id
,
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
)
objects
=
self
.
__reduct_overlap
(
list
(
map
(
lambda
x
:
{
'bbox'
:
x
[
'bbox'
],
'score'
:
x
[
'score'
]},
filter
(
lambda
x
:
x
[
'category_id'
]
==
object_category_id
,
self
.
__model_list
[
page_no
][
'layout_dets'
],
),
)
)
)
subject_object_relation_map
=
{}
subjects
.
sort
(
key
=
lambda
x
:
x
[
'bbox'
][
0
]
**
2
+
x
[
'bbox'
][
1
]
**
2
)
# get the distance !
all_bboxes
=
[]
for
v
in
subjects
:
all_bboxes
.
append
(
{
'category_id'
:
subject_category_id
,
'bbox'
:
v
[
'bbox'
],
'score'
:
v
[
'score'
],
}
)
for
v
in
objects
:
all_bboxes
.
append
(
{
'category_id'
:
object_category_id
,
'bbox'
:
v
[
'bbox'
],
'score'
:
v
[
'score'
],
}
)
N
=
len
(
all_bboxes
)
dis
=
[[
MAX_DIS_OF_POINT
]
*
N
for
_
in
range
(
N
)]
for
i
in
range
(
N
):
for
j
in
range
(
i
):
if
(
all_bboxes
[
i
][
'category_id'
]
==
subject_category_id
and
all_bboxes
[
j
][
'category_id'
]
==
subject_category_id
):
continue
subject_idx
,
object_idx
=
i
,
j
if
all_bboxes
[
j
][
'category_id'
]
==
subject_category_id
:
subject_idx
,
object_idx
=
j
,
i
if
(
search_overlap_between_boxes
(
subject_idx
,
object_idx
)
>=
MERGE_BOX_OVERLAP_AREA_RATIO
):
dis
[
i
][
j
]
=
float
(
'inf'
)
dis
[
j
][
i
]
=
dis
[
i
][
j
]
continue
dis
[
i
][
j
]
=
self
.
_bbox_distance
(
all_bboxes
[
subject_idx
][
'bbox'
],
all_bboxes
[
object_idx
][
'bbox'
]
)
dis
[
j
][
i
]
=
dis
[
i
][
j
]
used
=
set
()
for
i
in
range
(
N
):
# 求第 i 个 subject 所关联的 object
if
all_bboxes
[
i
][
'category_id'
]
!=
subject_category_id
:
continue
seen
=
set
()
candidates
=
[]
arr
=
[]
for
j
in
range
(
N
):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
all_bboxes
[
i
][
'bbox'
],
all_bboxes
[
j
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
if
(
all_bboxes
[
j
][
'category_id'
]
!=
object_category_id
or
j
in
used
or
dis
[
i
][
j
]
==
MAX_DIS_OF_POINT
):
continue
left
,
right
,
_
,
_
=
bbox_relative_pos
(
all_bboxes
[
i
][
'bbox'
],
all_bboxes
[
j
][
'bbox'
]
)
# 由 pos_flag_count 相关逻辑保证本段逻辑准确性
if
left
or
right
:
one_way_dis
=
all_bboxes
[
i
][
'bbox'
][
2
]
-
all_bboxes
[
i
][
'bbox'
][
0
]
else
:
one_way_dis
=
all_bboxes
[
i
][
'bbox'
][
3
]
-
all_bboxes
[
i
][
'bbox'
][
1
]
if
dis
[
i
][
j
]
>
one_way_dis
:
continue
arr
.
append
((
dis
[
i
][
j
],
j
))
arr
.
sort
(
key
=
lambda
x
:
x
[
0
])
if
len
(
arr
)
>
0
:
"""
bug: 离该subject 最近的 object 可能跨越了其它的 subject。
比如 [this subect] [some sbuject] [the nearest object of subject]
"""
if
may_find_other_nearest_bbox
(
i
,
arr
[
0
][
1
])
>=
arr
[
0
][
0
]:
candidates
.
append
(
arr
[
0
][
1
])
seen
.
add
(
arr
[
0
][
1
])
# 已经获取初始种子
for
j
in
set
(
candidates
):
tmp
=
[]
for
k
in
range
(
i
+
1
,
N
):
pos_flag_count
=
sum
(
list
(
map
(
lambda
x
:
1
if
x
else
0
,
bbox_relative_pos
(
all_bboxes
[
j
][
'bbox'
],
all_bboxes
[
k
][
'bbox'
]
),
)
)
)
if
pos_flag_count
>
1
:
continue
if
(
all_bboxes
[
k
][
'category_id'
]
!=
object_category_id
or
k
in
used
or
k
in
seen
or
dis
[
j
][
k
]
==
MAX_DIS_OF_POINT
or
dis
[
j
][
k
]
>
dis
[
i
][
j
]
):
continue
is_nearest
=
True
for
ni
in
range
(
i
+
1
,
N
):
if
ni
in
(
j
,
k
)
or
ni
in
used
or
ni
in
seen
:
continue
if
not
float_gt
(
dis
[
ni
][
k
],
dis
[
j
][
k
]):
is_nearest
=
False
break
if
is_nearest
:
nx0
,
ny0
,
nx1
,
ny1
=
expand_bbbox
(
list
(
seen
)
+
[
k
])
n_dis
=
bbox_distance
(
all_bboxes
[
i
][
'bbox'
],
[
nx0
,
ny0
,
nx1
,
ny1
]
)
if
float_gt
(
dis
[
i
][
j
],
n_dis
):
continue
tmp
.
append
(
k
)
seen
.
add
(
k
)
candidates
=
tmp
if
len
(
candidates
)
==
0
:
break
# 已经获取到某个 figure 下所有的最靠近的 captions,以及最靠近这些 captions 的 captions 。
# 先扩一下 bbox,
ox0
,
oy0
,
ox1
,
oy1
=
expand_bbbox
(
list
(
seen
)
+
[
i
])
ix0
,
iy0
,
ix1
,
iy1
=
all_bboxes
[
i
][
'bbox'
]
# 分成了 4 个截取空间,需要计算落在每个截取空间下 objects 合并后占据的矩形面积
caption_poses
=
[
[
ox0
,
oy0
,
ix0
,
oy1
],
[
ox0
,
oy0
,
ox1
,
iy0
],
[
ox0
,
iy1
,
ox1
,
oy1
],
[
ix1
,
oy0
,
ox1
,
oy1
],
]
caption_areas
=
[]
for
bbox
in
caption_poses
:
embed_arr
=
[]
for
idx
in
seen
:
if
(
calculate_overlap_area_in_bbox1_area_ratio
(
all_bboxes
[
idx
][
'bbox'
],
bbox
)
>
CAPATION_OVERLAP_AREA_RATIO
):
embed_arr
.
append
(
idx
)
if
len
(
embed_arr
)
>
0
:
embed_x0
=
min
([
all_bboxes
[
idx
][
'bbox'
][
0
]
for
idx
in
embed_arr
])
embed_y0
=
min
([
all_bboxes
[
idx
][
'bbox'
][
1
]
for
idx
in
embed_arr
])
embed_x1
=
max
([
all_bboxes
[
idx
][
'bbox'
][
2
]
for
idx
in
embed_arr
])
embed_y1
=
max
([
all_bboxes
[
idx
][
'bbox'
][
3
]
for
idx
in
embed_arr
])
caption_areas
.
append
(
int
(
abs
(
embed_x1
-
embed_x0
)
*
abs
(
embed_y1
-
embed_y0
))
)
else
:
caption_areas
.
append
(
0
)
subject_object_relation_map
[
i
]
=
[]
if
max
(
caption_areas
)
>
0
:
max_area_idx
=
caption_areas
.
index
(
max
(
caption_areas
))
caption_bbox
=
caption_poses
[
max_area_idx
]
for
j
in
seen
:
if
(
calculate_overlap_area_in_bbox1_area_ratio
(
all_bboxes
[
j
][
'bbox'
],
caption_bbox
)
>
CAPATION_OVERLAP_AREA_RATIO
):
used
.
add
(
j
)
subject_object_relation_map
[
i
].
append
(
j
)
for
i
in
sorted
(
subject_object_relation_map
.
keys
()):
result
=
{
'subject_body'
:
all_bboxes
[
i
][
'bbox'
],
'all'
:
all_bboxes
[
i
][
'bbox'
],
'score'
:
all_bboxes
[
i
][
'score'
],
}
if
len
(
subject_object_relation_map
[
i
])
>
0
:
x0
=
min
(
[
all_bboxes
[
j
][
'bbox'
][
0
]
for
j
in
subject_object_relation_map
[
i
]]
)
y0
=
min
(
[
all_bboxes
[
j
][
'bbox'
][
1
]
for
j
in
subject_object_relation_map
[
i
]]
)
x1
=
max
(
[
all_bboxes
[
j
][
'bbox'
][
2
]
for
j
in
subject_object_relation_map
[
i
]]
)
y1
=
max
(
[
all_bboxes
[
j
][
'bbox'
][
3
]
for
j
in
subject_object_relation_map
[
i
]]
)
result
[
'object_body'
]
=
[
x0
,
y0
,
x1
,
y1
]
result
[
'all'
]
=
[
min
(
x0
,
all_bboxes
[
i
][
'bbox'
][
0
]),
min
(
y0
,
all_bboxes
[
i
][
'bbox'
][
1
]),
max
(
x1
,
all_bboxes
[
i
][
'bbox'
][
2
]),
max
(
y1
,
all_bboxes
[
i
][
'bbox'
][
3
]),
]
ret
.
append
(
result
)
total_subject_object_dis
=
0
# 计算已经配对的 distance 距离
for
i
in
subject_object_relation_map
.
keys
():
for
j
in
subject_object_relation_map
[
i
]:
total_subject_object_dis
+=
bbox_distance
(
all_bboxes
[
i
][
'bbox'
],
all_bboxes
[
j
][
'bbox'
]
)
# 计算未匹配的 subject 和 object 的距离(非精确版)
with_caption_subject
=
set
(
[
key
for
key
in
subject_object_relation_map
.
keys
()
if
len
(
subject_object_relation_map
[
i
])
>
0
]
)
for
i
in
range
(
N
):
if
all_bboxes
[
i
][
'category_id'
]
!=
object_category_id
or
i
in
used
:
continue
candidates
=
[]
for
j
in
range
(
N
):
if
(
all_bboxes
[
j
][
'category_id'
]
!=
subject_category_id
or
j
in
with_caption_subject
):
continue
candidates
.
append
((
dis
[
i
][
j
],
j
))
if
len
(
candidates
)
>
0
:
candidates
.
sort
(
key
=
lambda
x
:
x
[
0
])
total_subject_object_dis
+=
candidates
[
0
][
1
]
with_caption_subject
.
add
(
j
)
return
ret
,
total_subject_object_dis
def
__tie_up_category_by_distance_v2
(
def
__tie_up_category_by_distance_v2
(
self
,
self
,
page_no
:
int
,
page_no
:
int
,
...
@@ -879,52 +489,12 @@ class MagicModel:
...
@@ -879,52 +489,12 @@ class MagicModel:
return
ret
return
ret
def
get_imgs
(
self
,
page_no
:
int
):
def
get_imgs
(
self
,
page_no
:
int
):
with_captions
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
3
,
4
)
return
self
.
get_imgs_v2
(
page_no
)
with_footnotes
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
3
,
CategoryId
.
ImageFootnote
)
ret
=
[]
N
,
M
=
len
(
with_captions
),
len
(
with_footnotes
)
assert
N
==
M
for
i
in
range
(
N
):
record
=
{
'score'
:
with_captions
[
i
][
'score'
],
'img_caption_bbox'
:
with_captions
[
i
].
get
(
'object_body'
,
None
),
'img_body_bbox'
:
with_captions
[
i
][
'subject_body'
],
'img_footnote_bbox'
:
with_footnotes
[
i
].
get
(
'object_body'
,
None
),
}
x0
=
min
(
with_captions
[
i
][
'all'
][
0
],
with_footnotes
[
i
][
'all'
][
0
])
y0
=
min
(
with_captions
[
i
][
'all'
][
1
],
with_footnotes
[
i
][
'all'
][
1
])
x1
=
max
(
with_captions
[
i
][
'all'
][
2
],
with_footnotes
[
i
][
'all'
][
2
])
y1
=
max
(
with_captions
[
i
][
'all'
][
3
],
with_footnotes
[
i
][
'all'
][
3
])
record
[
'bbox'
]
=
[
x0
,
y0
,
x1
,
y1
]
ret
.
append
(
record
)
return
ret
def
get_tables
(
def
get_tables
(
self
,
page_no
:
int
self
,
page_no
:
int
)
->
list
:
# 3个坐标, caption, table主体,table-note
)
->
list
:
# 3个坐标, caption, table主体,table-note
with_captions
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
5
,
6
)
return
self
.
get_tables_v2
(
page_no
)
with_footnotes
,
_
=
self
.
__tie_up_category_by_distance
(
page_no
,
5
,
7
)
ret
=
[]
N
,
M
=
len
(
with_captions
),
len
(
with_footnotes
)
assert
N
==
M
for
i
in
range
(
N
):
record
=
{
'score'
:
with_captions
[
i
][
'score'
],
'table_caption_bbox'
:
with_captions
[
i
].
get
(
'object_body'
,
None
),
'table_body_bbox'
:
with_captions
[
i
][
'subject_body'
],
'table_footnote_bbox'
:
with_footnotes
[
i
].
get
(
'object_body'
,
None
),
}
x0
=
min
(
with_captions
[
i
][
'all'
][
0
],
with_footnotes
[
i
][
'all'
][
0
])
y0
=
min
(
with_captions
[
i
][
'all'
][
1
],
with_footnotes
[
i
][
'all'
][
1
])
x1
=
max
(
with_captions
[
i
][
'all'
][
2
],
with_footnotes
[
i
][
'all'
][
2
])
y1
=
max
(
with_captions
[
i
][
'all'
][
3
],
with_footnotes
[
i
][
'all'
][
3
])
record
[
'bbox'
]
=
[
x0
,
y0
,
x1
,
y1
]
ret
.
append
(
record
)
return
ret
def
get_equations
(
self
,
page_no
:
int
)
->
list
:
# 有坐标,也有字
def
get_equations
(
self
,
page_no
:
int
)
->
list
:
# 有坐标,也有字
inline_equations
=
self
.
__get_blocks_by_type
(
inline_equations
=
self
.
__get_blocks_by_type
(
...
@@ -1043,4 +613,3 @@ class MagicModel:
...
@@ -1043,4 +613,3 @@ class MagicModel:
def
get_model_list
(
self
,
page_no
):
def
get_model_list
(
self
,
page_no
):
return
self
.
__model_list
[
page_no
]
return
self
.
__model_list
[
page_no
]
magic_pdf/pdf_parse_by_ocr.py
deleted
100644 → 0
View file @
ad9abc32
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
def
parse_pdf_by_ocr
(
dataset
:
Dataset
,
model_list
,
imageWriter
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
lang
=
None
,
):
return
pdf_parse_union
(
model_list
,
dataset
,
imageWriter
,
SupportedPdfParseMethod
.
OCR
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
debug_mode
=
debug_mode
,
lang
=
lang
,
)
magic_pdf/pdf_parse_by_txt.py
deleted
100644 → 0
View file @
ad9abc32
from
magic_pdf.config.enums
import
SupportedPdfParseMethod
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.pdf_parse_union_core_v2
import
pdf_parse_union
def
parse_pdf_by_txt
(
dataset
:
Dataset
,
model_list
,
imageWriter
,
start_page_id
=
0
,
end_page_id
=
None
,
debug_mode
=
False
,
lang
=
None
,
):
return
pdf_parse_union
(
model_list
,
dataset
,
imageWriter
,
SupportedPdfParseMethod
.
TXT
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
debug_mode
=
debug_mode
,
lang
=
lang
,
)
magic_pdf/pipe/AbsPipe.py
deleted
100644 → 0
View file @
ad9abc32
from
abc
import
ABC
,
abstractmethod
from
magic_pdf.config.drop_reason
import
DropReason
from
magic_pdf.config.make_content_config
import
DropMode
,
MakeMode
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.dict2md.ocr_mkcontent
import
union_make
from
magic_pdf.filter.pdf_classify_by_type
import
classify
from
magic_pdf.filter.pdf_meta_scan
import
pdf_meta_scan
from
magic_pdf.libs.json_compressor
import
JsonCompressor
class
AbsPipe
(
ABC
):
"""txt和ocr处理的抽象类."""
PIP_OCR
=
'ocr'
PIP_TXT
=
'txt'
def
__init__
(
self
,
dataset
:
Dataset
,
model_list
:
list
,
image_writer
:
DataWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
dataset
=
dataset
self
.
model_list
=
model_list
self
.
image_writer
=
image_writer
self
.
pdf_mid_data
=
None
# 未压缩
self
.
is_debug
=
is_debug
self
.
start_page_id
=
start_page_id
self
.
end_page_id
=
end_page_id
self
.
lang
=
lang
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
get_compress_pdf_mid_data
(
self
):
return
JsonCompressor
.
compress_json
(
self
.
pdf_mid_data
)
@
abstractmethod
def
pipe_classify
(
self
):
"""有状态的分类."""
raise
NotImplementedError
@
abstractmethod
def
pipe_analyze
(
self
):
"""有状态的跑模型分析."""
raise
NotImplementedError
@
abstractmethod
def
pipe_parse
(
self
):
"""有状态的解析."""
raise
NotImplementedError
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
content_list
=
AbsPipe
.
mk_uni_format
(
self
.
get_compress_pdf_mid_data
(),
img_parent_path
,
drop_mode
)
return
content_list
def
pipe_mk_markdown
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
,
md_make_mode
=
MakeMode
.
MM_MD
):
md_content
=
AbsPipe
.
mk_markdown
(
self
.
get_compress_pdf_mid_data
(),
img_parent_path
,
drop_mode
,
md_make_mode
)
return
md_content
@
staticmethod
def
classify
(
pdf_bytes
:
bytes
)
->
str
:
"""根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
pdf_meta
=
pdf_meta_scan
(
pdf_bytes
)
if
pdf_meta
.
get
(
'_need_drop'
,
False
):
# 如果返回了需要丢弃的标志,则抛出异常
raise
Exception
(
f
"pdf meta_scan need_drop,reason is
{
pdf_meta
[
'_drop_reason'
]
}
"
)
else
:
is_encrypted
=
pdf_meta
[
'is_encrypted'
]
is_needs_password
=
pdf_meta
[
'is_needs_password'
]
if
is_encrypted
or
is_needs_password
:
# 加密的,需要密码的,没有页面的,都不处理
raise
Exception
(
f
'pdf meta_scan need_drop,reason is
{
DropReason
.
ENCRYPTED
}
'
)
else
:
is_text_pdf
,
results
=
classify
(
pdf_meta
[
'total_page'
],
pdf_meta
[
'page_width_pts'
],
pdf_meta
[
'page_height_pts'
],
pdf_meta
[
'image_info_per_page'
],
pdf_meta
[
'text_len_per_page'
],
pdf_meta
[
'imgs_per_page'
],
pdf_meta
[
'text_layout_per_page'
],
pdf_meta
[
'invalid_chars'
],
)
if
is_text_pdf
:
return
AbsPipe
.
PIP_TXT
else
:
return
AbsPipe
.
PIP_OCR
@
staticmethod
def
mk_uni_format
(
compressed_pdf_mid_data
:
str
,
img_buket_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
)
->
list
:
"""根据pdf类型,生成统一格式content_list."""
pdf_mid_data
=
JsonCompressor
.
decompress_json
(
compressed_pdf_mid_data
)
pdf_info_list
=
pdf_mid_data
[
'pdf_info'
]
content_list
=
union_make
(
pdf_info_list
,
MakeMode
.
STANDARD_FORMAT
,
drop_mode
,
img_buket_path
)
return
content_list
@
staticmethod
def
mk_markdown
(
compressed_pdf_mid_data
:
str
,
img_buket_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
,
md_make_mode
=
MakeMode
.
MM_MD
)
->
list
:
"""根据pdf类型,markdown."""
pdf_mid_data
=
JsonCompressor
.
decompress_json
(
compressed_pdf_mid_data
)
pdf_info_list
=
pdf_mid_data
[
'pdf_info'
]
md_content
=
union_make
(
pdf_info_list
,
md_make_mode
,
drop_mode
,
img_buket_path
)
return
md_content
magic_pdf/pipe/OCRPipe.py
deleted
100644 → 0
View file @
ad9abc32
from
loguru
import
logger
from
magic_pdf.config.make_content_config
import
DropMode
,
MakeMode
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.model.doc_analyze_by_custom_model
import
doc_analyze
from
magic_pdf.pipe.AbsPipe
import
AbsPipe
from
magic_pdf.user_api
import
parse_ocr_pdf
class
OCRPipe
(
AbsPipe
):
def
__init__
(
self
,
dataset
:
Dataset
,
model_list
:
list
,
image_writer
:
DataWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
):
super
().
__init__
(
dataset
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
,
)
def
pipe_classify
(
self
):
pass
def
pipe_analyze
(
self
):
self
.
infer_res
=
doc_analyze
(
self
.
dataset
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
,
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
dataset
,
self
.
infer_res
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
,
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
logger
.
info
(
'ocr_pipe mk content list finished'
)
return
result
def
pipe_mk_markdown
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
,
md_make_mode
=
MakeMode
.
MM_MD
,
):
result
=
super
().
pipe_mk_markdown
(
img_parent_path
,
drop_mode
,
md_make_mode
)
logger
.
info
(
f
'ocr_pipe mk
{
md_make_mode
}
finished'
)
return
result
magic_pdf/pipe/TXTPipe.py
deleted
100644 → 0
View file @
ad9abc32
from
loguru
import
logger
from
magic_pdf.config.make_content_config
import
DropMode
,
MakeMode
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.model.doc_analyze_by_custom_model
import
doc_analyze
from
magic_pdf.pipe.AbsPipe
import
AbsPipe
from
magic_pdf.user_api
import
parse_txt_pdf
class
TXTPipe
(
AbsPipe
):
def
__init__
(
self
,
dataset
:
Dataset
,
model_list
:
list
,
image_writer
:
DataWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
().
__init__
(
dataset
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
pass
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
dataset
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_txt_pdf
(
self
.
dataset
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
logger
.
info
(
'txt_pipe mk content list finished'
)
return
result
def
pipe_mk_markdown
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
,
md_make_mode
=
MakeMode
.
MM_MD
):
result
=
super
().
pipe_mk_markdown
(
img_parent_path
,
drop_mode
,
md_make_mode
)
logger
.
info
(
f
'txt_pipe mk
{
md_make_mode
}
finished'
)
return
result
magic_pdf/pipe/UNIPipe.py
deleted
100644 → 0
View file @
ad9abc32
import
json
from
loguru
import
logger
from
magic_pdf.config.make_content_config
import
DropMode
,
MakeMode
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.model.doc_analyze_by_custom_model
import
doc_analyze
from
magic_pdf.pipe.AbsPipe
import
AbsPipe
from
magic_pdf.user_api
import
parse_ocr_pdf
,
parse_union_pdf
class
UNIPipe
(
AbsPipe
):
def
__init__
(
self
,
dataset
:
Dataset
,
jso_useful_key
:
dict
,
image_writer
:
DataWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
):
self
.
pdf_type
=
jso_useful_key
[
'_pdf_type'
]
super
().
__init__
(
dataset
,
jso_useful_key
[
'model_list'
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
,
)
if
len
(
self
.
model_list
)
==
0
:
self
.
input_model_is_empty
=
True
else
:
self
.
input_model_is_empty
=
False
def
pipe_classify
(
self
):
self
.
pdf_type
=
AbsPipe
.
classify
(
self
.
pdf_bytes
)
def
pipe_analyze
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
model_list
=
doc_analyze
(
self
.
dataset
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
,
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
model_list
=
doc_analyze
(
self
.
dataset
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
,
)
def
pipe_parse
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
pdf_mid_data
=
parse_union_pdf
(
self
.
dataset
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
,
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
dataset
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
,
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
NONE_WITH_REASON
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
logger
.
info
(
'uni_pipe mk content list finished'
)
return
result
def
pipe_mk_markdown
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
,
md_make_mode
=
MakeMode
.
MM_MD
,
):
result
=
super
().
pipe_mk_markdown
(
img_parent_path
,
drop_mode
,
md_make_mode
)
logger
.
info
(
f
'uni_pipe mk
{
md_make_mode
}
finished'
)
return
result
if
__name__
==
'__main__'
:
# 测试
from
magic_pdf.data.data_reader_writer
import
DataReader
drw
=
DataReader
(
r
'D:/project/20231108code-clean'
)
pdf_file_path
=
r
'linshixuqiu\19983-00.pdf'
model_file_path
=
r
'linshixuqiu\19983-00.json'
pdf_bytes
=
drw
.
read
(
pdf_file_path
)
model_json_txt
=
drw
.
read
(
model_file_path
).
decode
()
model_list
=
json
.
loads
(
model_json_txt
)
write_path
=
r
'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path
=
'imgs'
img_writer
=
DataWriter
(
join_path
(
write_path
,
img_bucket_path
))
# pdf_type = UNIPipe.classify(pdf_bytes)
# jso_useful_key = {
# "_pdf_type": pdf_type,
# "model_list": model_list
# }
jso_useful_key
=
{
'_pdf_type'
:
''
,
'model_list'
:
model_list
}
pipe
=
UNIPipe
(
pdf_bytes
,
jso_useful_key
,
img_writer
)
pipe
.
pipe_classify
()
pipe
.
pipe_parse
()
md_content
=
pipe
.
pipe_mk_markdown
(
img_bucket_path
)
content_list
=
pipe
.
pipe_mk_uni_format
(
img_bucket_path
)
md_writer
=
DataWriter
(
write_path
)
md_writer
.
write_string
(
'19983-00.md'
,
md_content
)
md_writer
.
write_string
(
'19983-00.json'
,
json
.
dumps
(
pipe
.
pdf_mid_data
,
ensure_ascii
=
False
,
indent
=
4
)
)
md_writer
.
write_string
(
'19983-00.txt'
,
str
(
content_list
))
magic_pdf/pipe/__init__.py
deleted
100644 → 0
View file @
ad9abc32
magic_pdf/rw/AbsReaderWriter.py
deleted
100644 → 0
View file @
ad9abc32
from
abc
import
ABC
,
abstractmethod
class AbsReaderWriter(ABC):
    """Abstract interface for storage backends (local disk, S3, ...).

    Concrete subclasses implement whole-file read/write plus a ranged
    binary read. The `mode` argument of read/write selects between text
    and raw-bytes handling via the two constants below.
    """

    # Mode constants shared by all implementations: pass one of these as
    # the `mode` argument to read()/write().
    MODE_TXT = "text"
    MODE_BIN = "binary"

    @abstractmethod
    def read(self, path: str, mode=MODE_TXT):
        """Return the full contents at `path` (str in MODE_TXT, bytes in MODE_BIN)."""
        raise NotImplementedError

    @abstractmethod
    def write(self, content: str, path: str, mode=MODE_TXT):
        """Write `content` to `path`; `content` type must match `mode`."""
        raise NotImplementedError

    @abstractmethod
    def read_offset(self, path: str, offset=0, limit=None) -> bytes:
        """Return up to `limit` bytes starting at `offset` (all remaining if limit is None)."""
        raise NotImplementedError
magic_pdf/rw/DiskReaderWriter.py
deleted
100644 → 0
View file @
ad9abc32
import
os
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
loguru
import
logger
class DiskReaderWriter(AbsReaderWriter):
    """AbsReaderWriter implementation backed by the local filesystem.

    Relative paths are resolved against `parent_path`; absolute paths are
    used as-is.
    """

    def __init__(self, parent_path, encoding="utf-8"):
        # Base directory used to resolve relative paths.
        self.path = parent_path
        # Text encoding used in MODE_TXT read/write.
        self.encoding = encoding

    def _abs_path(self, path):
        """Resolve `path` against the configured parent directory."""
        return path if os.path.isabs(path) else os.path.join(self.path, path)

    def read(self, path, mode=AbsReaderWriter.MODE_TXT):
        """Read a whole file; returns str in MODE_TXT, bytes in MODE_BIN.

        Raises:
            FileNotFoundError: if the resolved path does not exist.
            ValueError: if `mode` is not one of the class mode constants.
        """
        abspath = self._abs_path(path)
        if not os.path.exists(abspath):
            logger.error(f"file {abspath} not exists")
            # Fixed message typo ("no exists") and raise the specific
            # exception type; FileNotFoundError is an Exception subclass,
            # so existing broad handlers still catch it.
            raise FileNotFoundError(f"file {abspath} not exists")
        if mode == AbsReaderWriter.MODE_TXT:
            with open(abspath, "r", encoding=self.encoding) as f:
                return f.read()
        elif mode == AbsReaderWriter.MODE_BIN:
            with open(abspath, "rb") as f:
                return f.read()
        else:
            raise ValueError("Invalid mode. Use 'text' or 'binary'.")

    def write(self, content, path, mode=AbsReaderWriter.MODE_TXT):
        """Write `content` to `path`, creating parent directories as needed.

        Raises:
            ValueError: if `mode` is not one of the class mode constants.
        """
        abspath = self._abs_path(path)
        directory_path = os.path.dirname(abspath)
        # exist_ok=True avoids the check-then-create race; the truthiness
        # guard skips the call when the target has no directory component
        # (os.makedirs("") would raise).
        if directory_path:
            os.makedirs(directory_path, exist_ok=True)
        if mode == AbsReaderWriter.MODE_TXT:
            with open(abspath, "w", encoding=self.encoding, errors="replace") as f:
                f.write(content)
        elif mode == AbsReaderWriter.MODE_BIN:
            with open(abspath, "wb") as f:
                f.write(content)
        else:
            raise ValueError("Invalid mode. Use 'text' or 'binary'.")

    def read_offset(self, path: str, offset=0, limit=None):
        """Return up to `limit` bytes of `path` starting at byte `offset`.

        A `limit` of None reads to end-of-file (file.read(None) semantics).
        """
        abspath = self._abs_path(path)
        with open(abspath, "rb") as f:
            f.seek(offset)
            return f.read(limit)
if __name__ == "__main__":
    if 0:
        # Disabled manual check against a developer-local Windows path.
        file_path = "io/test/example.txt"
        drw = DiskReaderWriter("D:\projects\papayfork\Magic-PDF\magic_pdf")
        # Write content to the file (binary mode).
        drw.write(b"Hello, World!", path="io/test/example.txt", mode="binary")
        # Read the content back from the file.
        content = drw.read(path=file_path)
        if content:
            logger.info(f"从 {file_path} 读取的内容: {content}")
    if 1:
        # Ranged-read checks against a known fixture file:
        # expects /opt/data/pdf/resources/test/io/1.txt to contain b"ABCD!".
        drw = DiskReaderWriter("/opt/data/pdf/resources/test/io/")
        content_bin = drw.read_offset("1.txt")
        assert content_bin == b"ABCD!"
        content_bin = drw.read_offset("1.txt", offset=1, limit=2)
        assert content_bin == b"BC"
magic_pdf/rw/S3ReaderWriter.py
deleted
100644 → 0
View file @
ad9abc32
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.libs.commons
import
parse_bucket_key
,
join_path
import
boto3
from
loguru
import
logger
from
botocore.config
import
Config
class S3ReaderWriter(AbsReaderWriter):
    """AbsReaderWriter implementation backed by an S3-compatible object store.

    Paths starting with "s3://" are taken verbatim; anything else is joined
    onto the configured `parent_path` prefix.
    """

    def __init__(
        self,
        ak: str,
        sk: str,
        endpoint_url: str,
        addressing_style: str = "auto",
        parent_path: str = "",
    ):
        self.client = self._get_client(ak, sk, endpoint_url, addressing_style)
        # Prefix used to resolve relative keys.
        self.path = parent_path

    def _get_client(self, ak: str, sk: str, endpoint_url: str, addressing_style: str):
        """Build a boto3 S3 client with retries and the given addressing style."""
        client_config = Config(
            s3={"addressing_style": addressing_style},
            retries={"max_attempts": 5, "mode": "standard"},
        )
        return boto3.client(
            service_name="s3",
            aws_access_key_id=ak,
            aws_secret_access_key=sk,
            endpoint_url=endpoint_url,
            config=client_config,
        )

    def _resolve(self, rel_path):
        """Return a full s3:// path for `rel_path` (pass-through if already absolute)."""
        return rel_path if rel_path.startswith("s3://") else join_path(self.path, rel_path)

    def read(self, s3_relative_path, mode=AbsReaderWriter.MODE_TXT, encoding="utf-8"):
        """Fetch an object; returns str in MODE_TXT, bytes in MODE_BIN."""
        bucket_name, key = parse_bucket_key(self._resolve(s3_relative_path))
        res = self.client.get_object(Bucket=bucket_name, Key=key)
        body = res["Body"].read()
        if mode == AbsReaderWriter.MODE_TXT:
            return body.decode(encoding)  # Decode bytes to text
        if mode == AbsReaderWriter.MODE_BIN:
            return body
        raise ValueError("Invalid mode. Use 'text' or 'binary'.")

    def write(self, content, s3_relative_path, mode=AbsReaderWriter.MODE_TXT, encoding="utf-8"):
        """Upload `content` as one object; encodes str payloads in MODE_TXT."""
        if mode == AbsReaderWriter.MODE_TXT:
            body = content.encode(encoding)  # Encode text data as bytes
        elif mode == AbsReaderWriter.MODE_BIN:
            body = content
        else:
            raise ValueError("Invalid mode. Use 'text' or 'binary'.")
        s3_path = self._resolve(s3_relative_path)
        bucket_name, key = parse_bucket_key(s3_path)
        self.client.put_object(Body=body, Bucket=bucket_name, Key=key)
        logger.info(f"内容已写入 {s3_path}")

    def read_offset(self, path: str, offset=0, limit=None) -> bytes:
        """Ranged GET: up to `limit` bytes from byte `offset` (open-ended if no limit)."""
        bucket_name, key = parse_bucket_key(self._resolve(path))
        if limit:
            byte_range = f"bytes={offset}-{offset + limit - 1}"
        else:
            byte_range = f"bytes={offset}-"
        res = self.client.get_object(Bucket=bucket_name, Key=key, Range=byte_range)
        return res["Body"].read()
if __name__ == "__main__":
    if 0:
        # Disabled manual round-trip check; fill in real credentials to run.
        # Config the connection info
        ak = ""
        sk = ""
        endpoint_url = ""
        addressing_style = "auto"
        bucket_name = ""
        # Create an S3ReaderWriter object
        s3_reader_writer = S3ReaderWriter(
            ak, sk, endpoint_url, addressing_style, "s3://bucket_name/"
        )
        # Write text data to S3
        text_data = "This is some text data"
        s3_reader_writer.write(
            text_data,
            s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json",
            mode=AbsReaderWriter.MODE_TXT,
        )
        # Read text data from S3
        text_data_read = s3_reader_writer.read(
            s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json",
            mode=AbsReaderWriter.MODE_TXT
        )
        logger.info(f"Read text data from S3: {text_data_read}")
        # Write binary data to S3
        binary_data = b"This is some binary data"
        # BUG FIX: previously this uploaded `text_data` (a str) in binary
        # mode; it must upload the binary payload built on the line above.
        s3_reader_writer.write(
            binary_data,
            s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json",
            mode=AbsReaderWriter.MODE_BIN,
        )
        # Read binary data from S3
        binary_data_read = s3_reader_writer.read(
            s3_relative_path=f"s3://{bucket_name}/ebook/test/test.json",
            mode=AbsReaderWriter.MODE_BIN
        )
        logger.info(f"Read binary data from S3: {binary_data_read}")
        # Range Read text data from S3
        binary_data_read = s3_reader_writer.read_offset(
            path=f"s3://{bucket_name}/ebook/test/test.json", offset=0, limit=10
        )
        logger.info(f"Read binary data from S3: {binary_data_read}")
    if 1:
        # Ranged-read checks against a fixture object; connection details
        # come from the environment so no secrets live in the source.
        import os
        import json

        ak = os.getenv("AK", "")
        sk = os.getenv("SK", "")
        endpoint_url = os.getenv("ENDPOINT", "")
        bucket = os.getenv("S3_BUCKET", "")
        prefix = os.getenv("S3_PREFIX", "")
        key_basename = os.getenv("S3_KEY_BASENAME", "")
        s3_reader_writer = S3ReaderWriter(
            ak, sk, endpoint_url, "auto", f"s3://{bucket}/{prefix}"
        )
        # Full read: the fixture is expected to be a JSON-lines record.
        content_bin = s3_reader_writer.read_offset(key_basename)
        assert content_bin[:10] == b'{"track_id'
        assert content_bin[-10:] == b'r":null}}\n'
        # Ranged read from a known offset.
        content_bin = s3_reader_writer.read_offset(key_basename, offset=424, limit=426)
        jso = json.dumps(content_bin.decode("utf-8"))
        print(jso)
magic_pdf/rw/__init__.py
deleted
100644 → 0
View file @
ad9abc32
magic_pdf/user_api.py
deleted
100644 → 0
View file @
ad9abc32
"""用户输入: model数组,每个元素代表一个页面 pdf在s3的路径 截图保存的s3位置.
然后:
1)根据s3路径,调用spark集群的api,拿到ak,sk,endpoint,构造出s3PDFReader
2)根据用户输入的s3地址,调用spark集群的api,拿到ak,sk,endpoint,构造出s3ImageWriter
其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
"""
from
loguru
import
logger
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.version
import
__version__
from
magic_pdf.model.doc_analyze_by_custom_model
import
doc_analyze
from
magic_pdf.pdf_parse_by_ocr
import
parse_pdf_by_ocr
from
magic_pdf.pdf_parse_by_txt
import
parse_pdf_by_txt
from
magic_pdf.config.constants
import
PARSE_TYPE_TXT
,
PARSE_TYPE_OCR
def parse_txt_pdf(
    dataset: Dataset,
    model_list: list,
    imageWriter: DataWriter,
    is_debug=False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    *args,
    **kwargs
):
    """Parse a text-based PDF via the txt pipeline.

    Returns the pipeline's info dict, tagged with the parse type, the
    library version, and (when provided) the document language.
    """
    result = parse_pdf_by_txt(
        dataset,
        model_list,
        imageWriter,
        start_page_id=start_page_id,
        end_page_id=end_page_id,
        debug_mode=is_debug,
        lang=lang,
    )
    # Stamp provenance metadata onto the parse result.
    result['_parse_type'] = PARSE_TYPE_TXT
    result['_version_name'] = __version__
    if lang is not None:
        result['_lang'] = lang
    return result
def parse_ocr_pdf(
    dataset: Dataset,
    model_list: list,
    imageWriter: DataWriter,
    is_debug=False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    *args,
    **kwargs
):
    """Parse a scanned/OCR PDF via the OCR pipeline.

    Returns the pipeline's info dict, tagged with the parse type, the
    library version, and (when provided) the document language.
    """
    result = parse_pdf_by_ocr(
        dataset,
        model_list,
        imageWriter,
        start_page_id=start_page_id,
        end_page_id=end_page_id,
        debug_mode=is_debug,
        lang=lang,
    )
    # Stamp provenance metadata onto the parse result.
    result['_parse_type'] = PARSE_TYPE_OCR
    result['_version_name'] = __version__
    if lang is not None:
        result['_lang'] = lang
    return result
def parse_union_pdf(
    dataset: Dataset,
    model_list: list,
    imageWriter: DataWriter,
    is_debug=False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    *args,
    **kwargs
):
    """Parse a mixed text/OCR PDF: try the txt pipeline first, fall back to OCR.

    Raises:
        Exception: if both the txt and the OCR pipeline fail.
    """

    def parse_pdf(method):
        # Run one pipeline; a failure is logged and turned into None so the
        # caller can fall back instead of crashing.
        try:
            return method(
                dataset,
                model_list,
                imageWriter,
                start_page_id=start_page_id,
                end_page_id=end_page_id,
                debug_mode=is_debug,
                lang=lang,
            )
        except Exception as e:
            logger.exception(e)
            return None

    pdf_info_dict = parse_pdf(parse_pdf_by_txt)
    # Fall back to OCR when txt parsing errored OR flagged itself droppable.
    if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
        logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
        if len(model_list) == 0:
            # No precomputed model output: run layout/formula/table analysis
            # now, with OCR enabled, to produce one. Optional knobs come
            # from kwargs.
            layout_model = kwargs.get('layout_model', None)
            formula_enable = kwargs.get('formula_enable', None)
            table_enable = kwargs.get('table_enable', None)
            infer_res = doc_analyze(
                dataset,
                ocr=True,
                start_page_id=start_page_id,
                end_page_id=end_page_id,
                lang=lang,
                layout_model=layout_model,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
            # Rebind model_list so the closure above passes the fresh
            # inference results into the OCR pipeline.
            model_list = infer_res.get_infer_res()
        pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
        if pdf_info_dict is None:
            raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
        else:
            pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
    else:
        pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
    # Provenance metadata common to both outcomes.
    pdf_info_dict['_version_name'] = __version__
    if lang is not None:
        pdf_info_dict['_lang'] = lang
    return pdf_info_dict
tests/test_cli/test_bench_gpu.py
View file @
d13f3c6d
import
pytest
import
os
import
os
from
conf
import
conf
from
conf
import
conf
import
os
import
os
import
json
import
json
from
magic_pdf.pipe.UNIPipe
import
UNIPipe
from
magic_pdf.rw.DiskReaderWriter
import
DiskReaderWriter
from
lib
import
calculate_score
from
lib
import
calculate_score
import
shutil
import
shutil
pdf_res_path
=
conf
.
conf
[
"pdf_res_path"
]
pdf_res_path
=
conf
.
conf
[
"pdf_res_path"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment