Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
ea2f8ea0
Unverified
Commit
ea2f8ea0
authored
Nov 21, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Nov 21, 2024
Browse files
Merge branch 'dev' into dev
parents
e4810cee
23c8436e
Changes
59
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
637 additions
and
460 deletions
+637
-460
magic_pdf/config/constants.py
magic_pdf/config/constants.py
+53
-0
magic_pdf/config/drop_reason.py
magic_pdf/config/drop_reason.py
+35
-0
magic_pdf/config/drop_tag.py
magic_pdf/config/drop_tag.py
+19
-0
magic_pdf/config/make_content_config.py
magic_pdf/config/make_content_config.py
+11
-0
magic_pdf/config/model_block_type.py
magic_pdf/config/model_block_type.py
+2
-1
magic_pdf/config/ocr_content_type.py
magic_pdf/config/ocr_content_type.py
+0
-0
magic_pdf/data/read_api.py
magic_pdf/data/read_api.py
+1
-1
magic_pdf/dict2md/mkcontent.py
magic_pdf/dict2md/mkcontent.py
+226
-185
magic_pdf/dict2md/ocr_mkcontent.py
magic_pdf/dict2md/ocr_mkcontent.py
+7
-8
magic_pdf/filter/pdf_meta_scan.py
magic_pdf/filter/pdf_meta_scan.py
+101
-79
magic_pdf/integrations/rag/utils.py
magic_pdf/integrations/rag/utils.py
+4
-5
magic_pdf/libs/MakeContentConfig.py
magic_pdf/libs/MakeContentConfig.py
+0
-11
magic_pdf/libs/config_reader.py
magic_pdf/libs/config_reader.py
+5
-5
magic_pdf/libs/draw_bbox.py
magic_pdf/libs/draw_bbox.py
+3
-2
magic_pdf/libs/drop_reason.py
magic_pdf/libs/drop_reason.py
+0
-27
magic_pdf/libs/drop_tag.py
magic_pdf/libs/drop_tag.py
+0
-19
magic_pdf/libs/pdf_image_tools.py
magic_pdf/libs/pdf_image_tools.py
+9
-11
magic_pdf/model/magic_model.py
magic_pdf/model/magic_model.py
+13
-13
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+109
-59
magic_pdf/model/sub_modules/model_init.py
magic_pdf/model/sub_modules/model_init.py
+39
-34
No files found.
magic_pdf/
libs/C
onstants.py
→
magic_pdf/
config/c
onstants.py
View file @
ea2f8ea0
"""
span维度自定义字段
"""
"""span维度自定义字段."""
# span是否是跨页合并的
CROSS_PAGE
=
"
cross_page
"
CROSS_PAGE
=
'
cross_page
'
"""
block维度自定义字段
"""
# block中lines是否被删除
LINES_DELETED
=
"
lines_deleted
"
LINES_DELETED
=
'
lines_deleted
'
# table recognition max time default value
TABLE_MAX_TIME_VALUE
=
400
...
...
@@ -17,39 +15,39 @@ TABLE_MAX_TIME_VALUE = 400
TABLE_MAX_LEN
=
480
# table master structure dict
TABLE_MASTER_DICT
=
"
table_master_structure_dict.txt
"
TABLE_MASTER_DICT
=
'
table_master_structure_dict.txt
'
# table master dir
TABLE_MASTER_DIR
=
"
table_structure_tablemaster_infer/
"
TABLE_MASTER_DIR
=
'
table_structure_tablemaster_infer/
'
# pp detect model dir
DETECT_MODEL_DIR
=
"
ch_PP-OCRv4_det_infer
"
DETECT_MODEL_DIR
=
'
ch_PP-OCRv4_det_infer
'
# pp rec model dir
REC_MODEL_DIR
=
"
ch_PP-OCRv4_rec_infer
"
REC_MODEL_DIR
=
'
ch_PP-OCRv4_rec_infer
'
# pp rec char dict path
REC_CHAR_DICT
=
"
ppocr_keys_v1.txt
"
REC_CHAR_DICT
=
'
ppocr_keys_v1.txt
'
# pp rec copy rec directory
PP_REC_DIRECTORY
=
"
.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer
"
PP_REC_DIRECTORY
=
'
.paddleocr/whl/rec/ch/ch_PP-OCRv4_rec_infer
'
# pp rec copy det directory
PP_DET_DIRECTORY
=
"
.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer
"
PP_DET_DIRECTORY
=
'
.paddleocr/whl/det/ch/ch_PP-OCRv4_det_infer
'
class
MODEL_NAME
:
# pp table structure algorithm
TABLE_MASTER
=
"
tablemaster
"
TABLE_MASTER
=
'
tablemaster
'
# struct eqtable
STRUCT_EQTABLE
=
"
struct_eqtable
"
STRUCT_EQTABLE
=
'
struct_eqtable
'
DocLayout_YOLO
=
"
doclayout_yolo
"
DocLayout_YOLO
=
'
doclayout_yolo
'
LAYOUTLMv3
=
"
layoutlmv3
"
LAYOUTLMv3
=
'
layoutlmv3
'
YOLO_V8_MFD
=
"
yolo_v8_mfd
"
YOLO_V8_MFD
=
'
yolo_v8_mfd
'
UniMerNet_v2_Small
=
"
unimernet_small
"
UniMerNet_v2_Small
=
'
unimernet_small
'
RAPID_TABLE
=
"rapid_table"
\ No newline at end of file
RAPID_TABLE
=
'rapid_table'
magic_pdf/config/drop_reason.py
0 → 100644
View file @
ea2f8ea0
class DropReason:
    """String codes naming the reason a PDF (or part of it) is dropped.

    Each attribute is a plain ``str``. Several historical names contain
    typos (``BLCOK``, ``lOAD``); they are kept unchanged because callers
    reference them, and correctly spelled aliases with identical values
    are provided alongside them.
    """

    # Text blocks overlap each other horizontally, so the text order
    # cannot be located accurately.
    TEXT_BLCOK_HOR_OVERLAP = 'text_block_horizontal_overlap'
    # Correctly spelled alias of TEXT_BLCOK_HOR_OVERLAP.
    TEXT_BLOCK_HOR_OVERLAP = TEXT_BLCOK_HOR_OVERLAP
    # Blocks that must be kept overlap horizontally.
    USEFUL_BLOCK_HOR_OVERLAP = 'useful_block_horizontal_overlap'
    # Complicated layout, not supported for now.
    COMPLICATED_LAYOUT = 'complicated_layout'
    # Layouts with more than 2 columns are currently not supported.
    TOO_MANY_LAYOUT_COLUMNS = 'too_many_layout_columns'
    # The PDF contains colored blocks, which change the reading order;
    # PDFs with text blocks on a colored background are not supported yet.
    COLOR_BACKGROUND_TEXT_BOX = 'color_background_text_box'
    # Contains special images; too expensive to compute, hence dropped.
    HIGH_COMPUTATIONAL_lOAD_BY_IMGS = 'high_computational_load_by_imgs'
    # Special SVG graphics; too expensive to compute, hence dropped.
    HIGH_COMPUTATIONAL_lOAD_BY_SVGS = 'high_computational_load_by_svgs'
    # Computation exceeds capacity; too costly with the current method.
    HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = 'high_computational_load_by_total_pages'
    # Correctly cased aliases of the three HIGH_COMPUTATIONAL_lOAD_* names.
    HIGH_COMPUTATIONAL_LOAD_BY_IMGS = HIGH_COMPUTATIONAL_lOAD_BY_IMGS
    HIGH_COMPUTATIONAL_LOAD_BY_SVGS = HIGH_COMPUTATIONAL_lOAD_BY_SVGS
    HIGH_COMPUTATIONAL_LOAD_BY_TOTAL_PAGES = HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
    # Layout analysis failed. The value really does contain a space.
    MISS_DOC_LAYOUT_RESULT = 'missing doc_layout_result'
    # An exception occurred while parsing. NOTE: inside this class body the
    # name shadows the builtin ``Exception``; kept for compatibility, so do
    # not reference the builtin below this line within the class body.
    Exception = '_exception'
    # The PDF is encrypted.
    ENCRYPTED = 'encrypted'
    # The PDF has a total page count of 0.
    EMPTY_PDF = 'total_page=0'
    # Not a text-based PDF; it cannot be parsed directly.
    NOT_IS_TEXT_PDF = 'not_is_text_pdf'
    # Paragraphs cannot be separated cleanly.
    DENSE_SINGLE_LINE_BLOCK = 'dense_single_line_block'
    # Title detection failed.
    TITLE_DETECTION_FAILED = 'title_detection_failed'
    # Title level analysis failed (e.g. level-1, level-2, level-3 titles).
    TITLE_LEVEL_FAILED = 'title_level_failed'
    # Paragraph recognition failed.
    PARA_SPLIT_FAILED = 'para_split_failed'
    # Paragraph merging failed.
    PARA_MERGE_FAILED = 'para_merge_failed'
    # Unsupported language.
    NOT_ALLOW_LANGUAGE = 'not_allow_language'
    SPECIAL_PDF = 'special_pdf'
    # Text columns cannot be determined precisely.
    PSEUDO_SINGLE_COLUMN = 'pseudo_single_column'
    # The page layout cannot be analysed.
    CAN_NOT_DETECT_PAGE_LAYOUT = 'can_not_detect_page_layout'
    # Scaling produced a bbox with negative area.
    NEGATIVE_BBOX_AREA = 'negative_bbox_area'
    # Overlapping blocks cannot be separated.
    OVERLAP_BLOCKS_CAN_NOT_SEPARATION = 'overlap_blocks_can_t_separation'
magic_pdf/config/drop_tag.py
0 → 100644
View file @
ea2f8ea0
# Module-level tag strings. Where the original carried a comment, it is
# translated next to the constant; nothing further is visible from here.
COLOR_BG_HEADER_TXT_BLOCK = 'color_background_header_txt_block'

# Page number.
PAGE_NO = 'page-no'

# Text located inside the header/footer area.
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area'

# Vertical text.
VERTICAL_TEXT = 'vertical-text'

# Rotated text.
ROTATE_TEXT = 'rotate-text'

# Blank block at the page edge with no content at all.
EMPTY_SIDE_BLOCK = 'empty-side-block'

# Text lying on top of an image.
ON_IMAGE_TEXT = 'on-image-text'

# Text lying on top of a table.
ON_TABLE_TEXT = 'on-table-text'
class DropTag:
    """Namespace of plain-string drop tags.

    NOTE(review): the names suggest each tag labels a category of dropped
    content (page number, header, footer, ...); usage is not visible in
    this file -- confirm against callers.
    """

    PAGE_NUMBER = 'page_no'
    HEADER = 'header'
    FOOTER = 'footer'
    FOOTNOTE = 'footnote'
    NOT_IN_LAYOUT = 'not_in_layout'
    SPAN_OVERLAP = 'span_overlap'
    BLOCK_OVERLAP = 'block_overlap'
magic_pdf/config/make_content_config.py
0 → 100644
View file @
ea2f8ea0
class MakeMode:
    """Plain-string identifiers for the content-making modes.

    NOTE(review): presumably selects the output flavour (two markdown
    variants and a standard format); confirm against the code that
    consumes these values.
    """

    MM_MD = 'mm_markdown'
    NLP_MD = 'nlp_markdown'
    STANDARD_FORMAT = 'standard_format'
class DropMode:
    """Plain-string identifiers for the drop modes.

    NOTE(review): the names suggest the granularity at which content is
    dropped (whole PDF, single page, or not at all); confirm against the
    code that consumes these values.
    """

    WHOLE_PDF = 'whole_pdf'
    SINGLE_PAGE = 'single_page'
    NONE = 'none'
    NONE_WITH_REASON = 'none_with_reason'
magic_pdf/
libs/M
odel
B
lock
T
ype
Enum
.py
→
magic_pdf/
config/m
odel
_b
lock
_t
ype.py
View file @
ea2f8ea0
from
enum
import
Enum
class
ModelBlockTypeEnum
(
Enum
):
TITLE
=
0
PLAIN_TEXT
=
1
ABANDON
=
2
ISOLATE_FORMULA
=
8
EMBEDDING
=
13
ISOLATED
=
14
\ No newline at end of file
ISOLATED
=
14
magic_pdf/
libs
/ocr_content_type.py
→
magic_pdf/
config
/ocr_content_type.py
View file @
ea2f8ea0
File moved
magic_pdf/data/read_api.py
View file @
ea2f8ea0
...
...
@@ -35,7 +35,7 @@ def read_jsonl(
jsonl_d
=
[
json
.
loads
(
line
)
for
line
in
jsonl_bits
.
decode
().
split
(
'
\n
'
)
if
line
.
strip
()
]
for
d
in
jsonl_d
[:
5
]
:
for
d
in
jsonl_d
:
pdf_path
=
d
.
get
(
'file_location'
,
''
)
or
d
.
get
(
'path'
,
''
)
if
len
(
pdf_path
)
==
0
:
raise
EmptyData
(
'pdf file location is empty'
)
...
...
magic_pdf/dict2md/mkcontent.py
View file @
ea2f8ea0
This diff is collapsed.
Click to expand it.
magic_pdf/dict2md/ocr_mkcontent.py
View file @
ea2f8ea0
...
...
@@ -2,21 +2,20 @@ import re
from
loguru
import
logger
from
magic_pdf.config.make_content_config
import
DropMode
,
MakeMode
from
magic_pdf.config.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.libs.language
import
detect_lang
from
magic_pdf.libs.MakeContentConfig
import
DropMode
,
MakeMode
from
magic_pdf.libs.markdown_utils
import
ocr_escape_special_markdown_char
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.para.para_split_v3
import
ListLineTag
def
__is_hyphen_at_line_end
(
line
):
"""
Check if a line ends with one or more letters followed by a hyphen.
"""Check if a line ends with one or more letters followed by a hyphen.
Args:
line (str): The line of text to check.
Returns:
bool: True if the line ends with one or more letters followed by a hyphen, False otherwise.
"""
...
...
@@ -163,7 +162,7 @@ def merge_para_with_text(para_block):
if
span_type
in
[
ContentType
.
Text
,
ContentType
.
InterlineEquation
]:
para_text
+=
content
# 中文/日语/韩文语境下,content间不需要空格分隔
elif
span_type
==
ContentType
.
InlineEquation
:
para_text
+=
f
"
{
content
}
"
para_text
+=
f
'
{
content
}
'
else
:
if
span_type
in
[
ContentType
.
Text
,
ContentType
.
InlineEquation
]:
# 如果span是line的最后一个且末尾带有-连字符,那么末尾不应该加空格,同时应该把-删除
...
...
@@ -172,7 +171,7 @@ def merge_para_with_text(para_block):
elif
len
(
content
)
==
1
and
content
not
in
[
'A'
,
'I'
,
'a'
,
'i'
]
and
not
content
.
isdigit
():
para_text
+=
content
else
:
# 西方文本语境下 content间需要空格分隔
para_text
+=
f
"
{
content
}
"
para_text
+=
f
'
{
content
}
'
elif
span_type
==
ContentType
.
InterlineEquation
:
para_text
+=
content
else
:
...
...
magic_pdf/filter/pdf_meta_scan.py
View file @
ea2f8ea0
"""
输入: s3路径,每行一个
输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置
"""
"""输入: s3路径,每行一个 输出: pdf文件元信息,包括每一页上的所有图片的长宽高,bbox位置."""
import
sys
import
click
from
collections
import
Counter
from
magic_pdf.libs.commons
import
read_file
,
mymax
,
get_top_percent_list
from
magic_pdf.libs.commons
import
fitz
import
click
from
loguru
import
logger
from
collections
import
Counter
from
magic_pdf.libs.drop_reason
import
DropReason
from
magic_pdf.config.drop_reason
import
DropReason
from
magic_pdf.libs.commons
import
fitz
,
get_top_percent_list
,
mymax
,
read_file
from
magic_pdf.libs.language
import
detect_lang
from
magic_pdf.libs.pdf_check
import
detect_invalid_chars
...
...
@@ -19,8 +16,10 @@ junk_limit_min = 10
def
calculate_max_image_area_per_page
(
result
:
list
,
page_width_pts
,
page_height_pts
):
max_image_area_per_page
=
[
mymax
([(
x1
-
x0
)
*
(
y1
-
y0
)
for
x0
,
y0
,
x1
,
y1
,
_
in
page_img_sz
])
for
page_img_sz
in
result
]
max_image_area_per_page
=
[
mymax
([(
x1
-
x0
)
*
(
y1
-
y0
)
for
x0
,
y0
,
x1
,
y1
,
_
in
page_img_sz
])
for
page_img_sz
in
result
]
page_area
=
int
(
page_width_pts
)
*
int
(
page_height_pts
)
max_image_area_per_page
=
[
area
/
page_area
for
area
in
max_image_area_per_page
]
max_image_area_per_page
=
[
area
for
area
in
max_image_area_per_page
if
area
>
0.6
]
...
...
@@ -32,8 +31,10 @@ def process_image(page, junk_img_bojids=[]):
items
=
page
.
get_images
()
dedup
=
set
()
for
img
in
items
:
# 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
img_bojid
=
img
[
0
]
# 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
# 这里返回的是图片在page上的实际展示的大小。返回一个数组,每个元素第一部分是
img_bojid
=
img
[
0
]
# 在pdf文件中是全局唯一的,如果这个图反复出现在pdf里那么就可能是垃圾信息,例如水印、页眉页脚等
if
img_bojid
in
junk_img_bojids
:
# 如果是垃圾图像,就跳过
continue
recs
=
page
.
get_image_rects
(
img
,
transform
=
True
)
...
...
@@ -42,9 +43,17 @@ def process_image(page, junk_img_bojids=[]):
x0
,
y0
,
x1
,
y1
=
map
(
int
,
rec
)
width
=
x1
-
x0
height
=
y1
-
y0
if
(
x0
,
y0
,
x1
,
y1
,
img_bojid
)
in
dedup
:
# 这里面会出现一些重复的bbox,无需重复出现,需要去掉
if
(
x0
,
y0
,
x1
,
y1
,
img_bojid
,
)
in
dedup
:
# 这里面会出现一些重复的bbox,无需重复出现,需要去掉
continue
if
not
all
([
width
,
height
]):
# 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
if
not
all
(
[
width
,
height
]
):
# 长和宽任何一个都不能是0,否则这个图片不可见,没有实际意义
continue
dedup
.
add
((
x0
,
y0
,
x1
,
y1
,
img_bojid
))
page_result
.
append
([
x0
,
y0
,
x1
,
y1
,
img_bojid
])
...
...
@@ -52,29 +61,33 @@ def process_image(page, junk_img_bojids=[]):
def
get_image_info
(
doc
:
fitz
.
Document
,
page_width_pts
,
page_height_pts
)
->
list
:
"""
返回每个页面里的图片的四元组,每个页面多个图片。
"""
返回每个页面里的图片的四元组,每个页面多个图片。
:param doc:
:return:
"""
# 使用 Counter 计数 img_bojid 的出现次数
#
使用 Counter 计数 img_bojid 的出现次数
img_bojid_counter
=
Counter
(
img
[
0
]
for
page
in
doc
for
img
in
page
.
get_images
())
# 找出出现次数超过 len(doc) 半数的 img_bojid
#
找出出现次数超过 len(doc) 半数的 img_bojid
junk_limit
=
max
(
len
(
doc
)
*
0.5
,
junk_limit_min
)
# 对一些页数比较少的进行豁免
junk_img_bojids
=
[
img_bojid
for
img_bojid
,
count
in
img_bojid_counter
.
items
()
if
count
>=
junk_limit
]
#todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
#有两种扫描版,一种文字版,这里可能会有误判
#扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
#扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
#文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
junk_img_bojids
=
[
img_bojid
for
img_bojid
,
count
in
img_bojid_counter
.
items
()
if
count
>=
junk_limit
]
# todo 加个判断,用前十页就行,这些垃圾图片需要满足两个条件,不止出现的次数要足够多,而且图片占书页面积的比例要足够大,且图与图大小都差不多
# 有两种扫描版,一种文字版,这里可能会有误判
# 扫描版1:每页都有所有扫描页图片,特点是图占比大,每页展示1张
# 扫描版2,每页存储的扫描页图片数量递增,特点是图占比大,每页展示1张,需要清空junklist跑前50页图片信息用于分类判断
# 文字版1.每页存储所有图片,特点是图片占页面比例不大,每页展示可能为0也可能不止1张 这种pdf需要拿前10页抽样检测img大小和个数,如果符合需要清空junklist
imgs_len_list
=
[
len
(
page
.
get_images
())
for
page
in
doc
]
special_limit_pages
=
10
# 统一用前十页结果做判断
#
统一用前十页结果做判断
result
=
[]
break_loop
=
False
for
i
,
page
in
enumerate
(
doc
):
...
...
@@ -82,12 +95,18 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
break
if
i
>=
special_limit_pages
:
break
page_result
=
process_image
(
page
)
# 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
page_result
=
process_image
(
page
)
# 这里不传junk_img_bojids,拿前十页所有图片信息用于后续分析
result
.
append
(
page_result
)
for
item
in
result
:
if
not
any
(
item
):
# 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
if
max
(
imgs_len_list
)
==
min
(
imgs_len_list
)
and
max
(
imgs_len_list
)
>=
junk_limit_min
:
# 如果是特殊文字版,就把junklist置空并break
if
not
any
(
item
):
# 如果任何一页没有图片,说明是个文字版,需要判断是否为特殊文字版
if
(
max
(
imgs_len_list
)
==
min
(
imgs_len_list
)
and
max
(
imgs_len_list
)
>=
junk_limit_min
):
# 如果是特殊文字版,就把junklist置空并break
junk_img_bojids
=
[]
else
:
# 不是特殊文字版,是个普通文字版,但是存在垃圾图片,不置空junklist
pass
...
...
@@ -98,20 +117,23 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
top_eighty_percent
=
get_top_percent_list
(
imgs_len_list
,
0.8
)
# 检查前80%的元素是否都相等
if
len
(
set
(
top_eighty_percent
))
==
1
and
max
(
imgs_len_list
)
>=
junk_limit_min
:
# # 如果前10页跑完都有图,根据每页图片数量是否相等判断是否需要清除junklist
# if max(imgs_len_list) == min(imgs_len_list) and max(imgs_len_list) >= junk_limit_min:
#前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
max_image_area_per_page
=
calculate_max_image_area_per_page
(
result
,
page_width_pts
,
page_height_pts
)
if
len
(
max_image_area_per_page
)
<
0.8
*
special_limit_pages
:
# 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
# 前10页都有图,且每页数量一致,需要检测图片大小占页面的比例判断是否需要清除junklist
max_image_area_per_page
=
calculate_max_image_area_per_page
(
result
,
page_width_pts
,
page_height_pts
)
if
(
len
(
max_image_area_per_page
)
<
0.8
*
special_limit_pages
):
# 前10页不全是大图,说明可能是个文字版pdf,把垃圾图片list置空
junk_img_bojids
=
[]
else
:
# 前10页都有图,而且80%都是大图,且每页图片数量一致并都很多,说明是扫描版1,不需要清空junklist
pass
else
:
# 每页图片数量不一致,需要清掉junklist全量跑前50页图片
junk_img_bojids
=
[]
#正式进入取前50页图片的信息流程
#
正式进入取前50页图片的信息流程
result
=
[]
for
i
,
page
in
enumerate
(
doc
):
if
i
>=
scan_max_page
:
...
...
@@ -126,7 +148,7 @@ def get_image_info(doc: fitz.Document, page_width_pts, page_height_pts) -> list:
def
get_pdf_page_size_pts
(
doc
:
fitz
.
Document
):
page_cnt
=
len
(
doc
)
l
:
int
=
min
(
page_cnt
,
50
)
#把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
#
把所有宽度和高度塞到两个list 分别取中位数(中间遇到了个在纵页里塞横页的pdf,导致宽高互换了)
page_width_list
=
[]
page_height_list
=
[]
for
i
in
range
(
l
):
...
...
@@ -152,8 +174,8 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
# 拿所有text的blocks
# text_block = page.get_text("words")
# text_block_len = sum([len(t[4]) for t in text_block])
#拿所有text的str
text_block
=
page
.
get_text
(
"
text
"
)
#
拿所有text的str
text_block
=
page
.
get_text
(
'
text
'
)
text_block_len
=
len
(
text_block
)
# logger.info(f"page {page.number} text_block_len: {text_block_len}")
text_len_lst
.
append
(
text_block_len
)
...
...
@@ -162,15 +184,13 @@ def get_pdf_textlen_per_page(doc: fitz.Document):
def
get_pdf_text_layout_per_page
(
doc
:
fitz
.
Document
):
"""
根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
"""根据PDF文档的每一页文本布局,判断该页的文本布局是横向、纵向还是未知。
Args:
doc (fitz.Document): PDF文档对象。
Returns:
List[str]: 每一页的文本布局(横向、纵向、未知)。
"""
text_layout_list
=
[]
...
...
@@ -180,11 +200,11 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# 创建每一页的纵向和横向的文本行数计数器
vertical_count
=
0
horizontal_count
=
0
text_dict
=
page
.
get_text
(
"
dict
"
)
if
"
blocks
"
in
text_dict
:
for
block
in
text_dict
[
"
blocks
"
]:
text_dict
=
page
.
get_text
(
'
dict
'
)
if
'
blocks
'
in
text_dict
:
for
block
in
text_dict
[
'
blocks
'
]:
if
'lines'
in
block
:
for
line
in
block
[
"
lines
"
]:
for
line
in
block
[
'
lines
'
]:
# 获取line的bbox顶点坐标
x0
,
y0
,
x1
,
y1
=
line
[
'bbox'
]
# 计算bbox的宽高
...
...
@@ -199,8 +219,12 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
if
len
(
font_sizes
)
>
0
:
average_font_size
=
sum
(
font_sizes
)
/
len
(
font_sizes
)
else
:
average_font_size
=
10
# 有的line拿不到font_size,先定一个阈值100
if
area
<=
average_font_size
**
2
:
# 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
average_font_size
=
(
10
# 有的line拿不到font_size,先定一个阈值100
)
if
(
area
<=
average_font_size
**
2
):
# 判断bbox的面积是否小于平均字体大小的平方,单字无法计算是横向还是纵向
continue
else
:
if
'wmode'
in
line
:
# 通过wmode判断文本方向
...
...
@@ -228,22 +252,22 @@ def get_pdf_text_layout_per_page(doc: fitz.Document):
# print(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
# 判断每一页的文本布局
if
vertical_count
==
0
and
horizontal_count
==
0
:
# 该页没有文本,无法判断
text_layout_list
.
append
(
"
unknow
"
)
text_layout_list
.
append
(
'
unknow
'
)
continue
else
:
if
vertical_count
>
horizontal_count
:
# 该页的文本纵向行数大于横向的
text_layout_list
.
append
(
"
vertical
"
)
text_layout_list
.
append
(
'
vertical
'
)
else
:
# 该页的文本横向行数大于纵向的
text_layout_list
.
append
(
"
horizontal
"
)
text_layout_list
.
append
(
'
horizontal
'
)
# logger.info(f"page_id: {page_id}, vertical_count: {vertical_count}, horizontal_count: {horizontal_count}")
return
text_layout_list
'''
定义一个自定义异常用来抛出单页svg太多的pdf
'''
"""
定义一个自定义异常用来抛出单页svg太多的pdf
"""
class
PageSvgsTooManyError
(
Exception
):
def
__init__
(
self
,
message
=
"
Page SVGs are too many
"
):
def
__init__
(
self
,
message
=
'
Page SVGs are too many
'
):
self
.
message
=
message
super
().
__init__
(
self
.
message
)
...
...
@@ -285,7 +309,7 @@ def get_language(doc: fitz.Document):
if
page_id
>=
scan_max_page
:
break
# 拿所有text的str
text_block
=
page
.
get_text
(
"
text
"
)
text_block
=
page
.
get_text
(
'
text
'
)
page_language
=
detect_lang
(
text_block
)
language_lst
.
append
(
page_language
)
...
...
@@ -299,9 +323,7 @@ def get_language(doc: fitz.Document):
def
check_invalid_chars
(
pdf_bytes
):
"""
乱码检测
"""
"""乱码检测."""
return
detect_invalid_chars
(
pdf_bytes
)
...
...
@@ -311,13 +333,13 @@ def pdf_meta_scan(pdf_bytes: bytes):
:param pdf_bytes: pdf文件的二进制数据
几个维度来评价:是否加密,是否需要密码,纸张大小,总页数,是否文字可提取
"""
doc
=
fitz
.
open
(
"
pdf
"
,
pdf_bytes
)
doc
=
fitz
.
open
(
'
pdf
'
,
pdf_bytes
)
is_needs_password
=
doc
.
needs_pass
is_encrypted
=
doc
.
is_encrypted
total_page
=
len
(
doc
)
if
total_page
==
0
:
logger
.
warning
(
f
"
drop this pdf, drop_reason:
{
DropReason
.
EMPTY_PDF
}
"
)
result
=
{
"
_need_drop
"
:
True
,
"
_drop_reason
"
:
DropReason
.
EMPTY_PDF
}
logger
.
warning
(
f
'
drop this pdf, drop_reason:
{
DropReason
.
EMPTY_PDF
}
'
)
result
=
{
'
_need_drop
'
:
True
,
'
_drop_reason
'
:
DropReason
.
EMPTY_PDF
}
return
result
else
:
page_width_pts
,
page_height_pts
=
get_pdf_page_size_pts
(
doc
)
...
...
@@ -328,7 +350,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
imgs_per_page
=
get_imgs_per_page
(
doc
)
# logger.info(f"imgs_per_page: {imgs_per_page}")
image_info_per_page
,
junk_img_bojids
=
get_image_info
(
doc
,
page_width_pts
,
page_height_pts
)
image_info_per_page
,
junk_img_bojids
=
get_image_info
(
doc
,
page_width_pts
,
page_height_pts
)
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
text_len_per_page
=
get_pdf_textlen_per_page
(
doc
)
# logger.info(f"text_len_per_page: {text_len_per_page}")
...
...
@@ -341,20 +365,20 @@ def pdf_meta_scan(pdf_bytes: bytes):
# 最后输出一条json
res
=
{
"
is_needs_password
"
:
is_needs_password
,
"
is_encrypted
"
:
is_encrypted
,
"
total_page
"
:
total_page
,
"
page_width_pts
"
:
int
(
page_width_pts
),
"
page_height_pts
"
:
int
(
page_height_pts
),
"
image_info_per_page
"
:
image_info_per_page
,
"
text_len_per_page
"
:
text_len_per_page
,
"
text_layout_per_page
"
:
text_layout_per_page
,
"
text_language
"
:
text_language
,
'
is_needs_password
'
:
is_needs_password
,
'
is_encrypted
'
:
is_encrypted
,
'
total_page
'
:
total_page
,
'
page_width_pts
'
:
int
(
page_width_pts
),
'
page_height_pts
'
:
int
(
page_height_pts
),
'
image_info_per_page
'
:
image_info_per_page
,
'
text_len_per_page
'
:
text_len_per_page
,
'
text_layout_per_page
'
:
text_layout_per_page
,
'
text_language
'
:
text_language
,
# "svgs_per_page": svgs_per_page,
"
imgs_per_page
"
:
imgs_per_page
,
# 增加每页img数量list
"
junk_img_bojids
"
:
junk_img_bojids
,
# 增加垃圾图片的bojid list
"
invalid_chars
"
:
invalid_chars
,
"
metadata
"
:
doc
.
metadata
'
imgs_per_page
'
:
imgs_per_page
,
# 增加每页img数量list
'
junk_img_bojids
'
:
junk_img_bojids
,
# 增加垃圾图片的bojid list
'
invalid_chars
'
:
invalid_chars
,
'
metadata
'
:
doc
.
metadata
,
}
# logger.info(json.dumps(res, ensure_ascii=False))
return
res
...
...
@@ -364,14 +388,12 @@ def pdf_meta_scan(pdf_bytes: bytes):
@
click
.
option
(
'--s3-pdf-path'
,
help
=
's3上pdf文件的路径'
)
@
click
.
option
(
'--s3-profile'
,
help
=
's3上的profile'
)
def
main
(
s3_pdf_path
:
str
,
s3_profile
:
str
):
"""
"""
""""""
try
:
file_content
=
read_file
(
s3_pdf_path
,
s3_profile
)
pdf_meta_scan
(
file_content
)
except
Exception
as
e
:
print
(
f
"
ERROR:
{
s3_pdf_path
}
,
{
e
}
"
,
file
=
sys
.
stderr
)
print
(
f
'
ERROR:
{
s3_pdf_path
}
,
{
e
}
'
,
file
=
sys
.
stderr
)
logger
.
exception
(
e
)
...
...
@@ -381,7 +403,7 @@ if __name__ == '__main__':
# "D:\project/20231108code-clean\pdf_cost_time\竖排例子\三国演义_繁体竖排版.pdf"
# "D:\project/20231108code-clean\pdf_cost_time\scihub\scihub_86800000\libgen.scimag86880000-86880999.zip_10.1021/acsami.1c03109.s002.pdf"
# "D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_18600000/libgen.scimag18645000-18645999.zip_10.1021/om3006239.pdf"
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")
# file_content = read_file("D:/project/20231108code-clean/pdf_cost_time/scihub/scihub_31000000/libgen.scimag31098000-31098999.zip_10.1109/isit.2006.261791.pdf","")
# noqa: E501
# file_content = read_file("D:\project/20231108code-clean\pdf_cost_time\竖排例子\净空法师_大乘无量寿.pdf","")
# doc = fitz.open("pdf", file_content)
# text_layout_lst = get_pdf_text_layout_per_page(doc)
...
...
magic_pdf/integrations/rag/utils.py
View file @
ea2f8ea0
...
...
@@ -5,14 +5,13 @@ from pathlib import Path
from
loguru
import
logger
import
magic_pdf.model
as
model_config
from
magic_pdf.config.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.data.data_reader_writer
import
FileBasedDataReader
from
magic_pdf.dict2md.ocr_mkcontent
import
merge_para_with_text
from
magic_pdf.integrations.rag.type
import
(
CategoryType
,
ContentObject
,
ElementRelation
,
ElementRelType
,
LayoutElements
,
LayoutElementsExtra
,
PageInfo
)
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
ContentType
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.rw.DiskReaderWriter
import
DiskReaderWriter
from
magic_pdf.tools.common
import
do_parse
,
prepare_env
...
...
@@ -224,8 +223,8 @@ def inference(path, output_dir, method):
str
(
Path
(
path
).
stem
),
method
)
def
read_fn
(
path
):
disk_rw
=
DiskReaderWrit
er
(
os
.
path
.
dirname
(
path
))
return
disk_rw
.
read
(
os
.
path
.
basename
(
path
)
,
AbsReaderWriter
.
MODE_BIN
)
disk_rw
=
FileBasedDataRead
er
(
os
.
path
.
dirname
(
path
))
return
disk_rw
.
read
(
os
.
path
.
basename
(
path
))
def
parse_doc
(
doc_path
:
str
):
try
:
...
...
magic_pdf/libs/MakeContentConfig.py
deleted
100644 → 0
View file @
e4810cee
class MakeMode:
    """String constants naming the content-making modes.

    NOTE(review): how each mode is interpreted is not visible here;
    verify against the consumers of these values.
    """

    MM_MD = "mm_markdown"
    NLP_MD = "nlp_markdown"
    STANDARD_FORMAT = "standard_format"
class DropMode:
    """String constants naming the drop modes.

    NOTE(review): how each mode is interpreted is not visible here;
    verify against the consumers of these values.
    """

    WHOLE_PDF = "whole_pdf"
    SINGLE_PAGE = "single_page"
    NONE = "none"
    NONE_WITH_REASON = "none_with_reason"
magic_pdf/libs/config_reader.py
View file @
ea2f8ea0
...
...
@@ -5,7 +5,7 @@ import os
from
loguru
import
logger
from
magic_pdf.
libs.C
onstants
import
MODEL_NAME
from
magic_pdf.
config.c
onstants
import
MODEL_NAME
from
magic_pdf.libs.commons
import
parse_bucket_key
# 定义配置文件名常量
...
...
@@ -99,7 +99,7 @@ def get_table_recog_config():
def
get_layout_config
():
config
=
read_config
()
layout_config
=
config
.
get
(
"
layout-config
"
)
layout_config
=
config
.
get
(
'
layout-config
'
)
if
layout_config
is
None
:
logger
.
warning
(
f
"'layout-config' not found in
{
CONFIG_FILE_NAME
}
, use '
{
MODEL_NAME
.
LAYOUTLMv3
}
' as default"
)
return
json
.
loads
(
f
'{{"model": "
{
MODEL_NAME
.
LAYOUTLMv3
}
"}}'
)
...
...
@@ -109,7 +109,7 @@ def get_layout_config():
def
get_formula_config
():
config
=
read_config
()
formula_config
=
config
.
get
(
"
formula-config
"
)
formula_config
=
config
.
get
(
'
formula-config
'
)
if
formula_config
is
None
:
logger
.
warning
(
f
"'formula-config' not found in
{
CONFIG_FILE_NAME
}
, use 'True' as default"
)
return
json
.
loads
(
f
'{{"mfd_model": "
{
MODEL_NAME
.
YOLO_V8_MFD
}
","mfr_model": "
{
MODEL_NAME
.
UniMerNet_v2_Small
}
","enable": true}}'
)
...
...
@@ -117,5 +117,5 @@ def get_formula_config():
return
formula_config
if
__name__
==
"
__main__
"
:
ak
,
sk
,
endpoint
=
get_s3_config
(
"
llm-raw
"
)
if
__name__
==
'
__main__
'
:
ak
,
sk
,
endpoint
=
get_s3_config
(
'
llm-raw
'
)
magic_pdf/libs/draw_bbox.py
View file @
ea2f8ea0
from
magic_pdf.config.constants
import
CROSS_PAGE
from
magic_pdf.config.ocr_content_type
import
(
BlockType
,
CategoryId
,
ContentType
)
from
magic_pdf.data.dataset
import
PymuDocDataset
from
magic_pdf.libs.commons
import
fitz
# PyMuPDF
from
magic_pdf.libs.Constants
import
CROSS_PAGE
from
magic_pdf.libs.ocr_content_type
import
BlockType
,
CategoryId
,
ContentType
from
magic_pdf.model.magic_model
import
MagicModel
...
...
magic_pdf/libs/drop_reason.py
deleted
100644 → 0
View file @
e4810cee
class DropReason:
    """String codes naming the reason a PDF (or part of it) is dropped.

    Every attribute is a plain ``str``. Names containing historical typos
    (``BLCOK``, ``lOAD``) are preserved because callers reference them;
    correctly spelled aliases carrying the same values are added next to
    them.
    """

    # Text blocks overlap each other horizontally, so the text order
    # cannot be located accurately.
    TEXT_BLCOK_HOR_OVERLAP = "text_block_horizontal_overlap"
    # Correctly spelled alias of TEXT_BLCOK_HOR_OVERLAP.
    TEXT_BLOCK_HOR_OVERLAP = TEXT_BLCOK_HOR_OVERLAP
    # Blocks that must be kept overlap horizontally.
    USEFUL_BLOCK_HOR_OVERLAP = "useful_block_horizontal_overlap"
    # Complicated layout, not supported for now.
    COMPLICATED_LAYOUT = "complicated_layout"
    # Layouts with more than 2 columns are currently not supported.
    TOO_MANY_LAYOUT_COLUMNS = "too_many_layout_columns"
    # The PDF contains colored blocks, which change the reading order;
    # PDFs with text blocks on a colored background are not supported yet.
    COLOR_BACKGROUND_TEXT_BOX = "color_background_text_box"
    # Contains special images; too expensive to compute, hence dropped.
    HIGH_COMPUTATIONAL_lOAD_BY_IMGS = "high_computational_load_by_imgs"
    # Special SVG graphics; too expensive to compute, hence dropped.
    HIGH_COMPUTATIONAL_lOAD_BY_SVGS = "high_computational_load_by_svgs"
    # Computation exceeds capacity; too costly with the current method.
    HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES = "high_computational_load_by_total_pages"
    # Correctly cased aliases of the three HIGH_COMPUTATIONAL_lOAD_* names.
    HIGH_COMPUTATIONAL_LOAD_BY_IMGS = HIGH_COMPUTATIONAL_lOAD_BY_IMGS
    HIGH_COMPUTATIONAL_LOAD_BY_SVGS = HIGH_COMPUTATIONAL_lOAD_BY_SVGS
    HIGH_COMPUTATIONAL_LOAD_BY_TOTAL_PAGES = HIGH_COMPUTATIONAL_lOAD_BY_TOTAL_PAGES
    # Layout analysis failed. The value really does contain a space.
    MISS_DOC_LAYOUT_RESULT = "missing doc_layout_result"
    # An exception occurred while parsing. NOTE: inside this class body the
    # name shadows the builtin ``Exception``; kept for compatibility, so do
    # not reference the builtin below this line within the class body.
    Exception = "_exception"
    # The PDF is encrypted.
    ENCRYPTED = "encrypted"
    # The PDF has a total page count of 0.
    EMPTY_PDF = "total_page=0"
    # Not a text-based PDF; it cannot be parsed directly.
    NOT_IS_TEXT_PDF = "not_is_text_pdf"
    # Paragraphs cannot be separated cleanly.
    DENSE_SINGLE_LINE_BLOCK = "dense_single_line_block"
    # Title detection failed.
    TITLE_DETECTION_FAILED = "title_detection_failed"
    # Title level analysis failed (e.g. level-1, level-2, level-3 titles).
    TITLE_LEVEL_FAILED = "title_level_failed"
    # Paragraph recognition failed.
    PARA_SPLIT_FAILED = "para_split_failed"
    # Paragraph merging failed.
    PARA_MERGE_FAILED = "para_merge_failed"
    # Unsupported language.
    NOT_ALLOW_LANGUAGE = "not_allow_language"
    SPECIAL_PDF = "special_pdf"
    # Text columns cannot be determined precisely.
    PSEUDO_SINGLE_COLUMN = "pseudo_single_column"
    # The page layout cannot be analysed.
    CAN_NOT_DETECT_PAGE_LAYOUT = "can_not_detect_page_layout"
    # Scaling produced a bbox with negative area.
    NEGATIVE_BBOX_AREA = "negative_bbox_area"
    # Overlapping blocks cannot be separated.
    OVERLAP_BLOCKS_CAN_NOT_SEPARATION = "overlap_blocks_can_t_separation"
\ No newline at end of file
magic_pdf/libs/drop_tag.py
deleted
100644 → 0
View file @
e4810cee
# Module-level tag strings; the original per-constant comments are
# translated above each value.
COLOR_BG_HEADER_TXT_BLOCK = "color_background_header_txt_block"

# Page number.
PAGE_NO = "page-no"

# Text located inside the header/footer area.
CONTENT_IN_FOOT_OR_HEADER = 'in-foot-header-area'

# Vertical text.
VERTICAL_TEXT = 'vertical-text'

# Rotated text.
ROTATE_TEXT = 'rotate-text'

# Blank block at the page edge with no content at all.
EMPTY_SIDE_BLOCK = 'empty-side-block'

# Text lying on top of an image.
ON_IMAGE_TEXT = 'on-image-text'

# Text lying on top of a table.
ON_TABLE_TEXT = 'on-table-text'
class DropTag:
    """Namespace of plain-string drop tags.

    NOTE(review): the names suggest each tag labels a category of dropped
    content; usage is not visible in this file -- confirm against callers.
    """

    PAGE_NUMBER = "page_no"
    HEADER = "header"
    FOOTER = "footer"
    FOOTNOTE = "footnote"
    NOT_IN_LAYOUT = "not_in_layout"
    SPAN_OVERLAP = "span_overlap"
    BLOCK_OVERLAP = "block_overlap"
magic_pdf/libs/pdf_image_tools.py
View file @
ea2f8ea0
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.libs.commons
import
fitz
from
magic_pdf.libs.commons
import
join_path
from
magic_pdf.data.data_reader_writer
import
DataWriter
from
magic_pdf.libs.commons
import
fitz
,
join_path
from
magic_pdf.libs.hash_utils
import
compute_sha256
def
cut_image
(
bbox
:
tuple
,
page_num
:
int
,
page
:
fitz
.
Page
,
return_path
,
imageWriter
:
AbsReaderWriter
):
"""
从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
"""
def
cut_image
(
bbox
:
tuple
,
page_num
:
int
,
page
:
fitz
.
Page
,
return_path
,
imageWriter
:
DataWriter
):
"""从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
图片存放在save_path下,文件名是:
{page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
# 拼接文件名
filename
=
f
"
{
page_num
}
_
{
int
(
bbox
[
0
])
}
_
{
int
(
bbox
[
1
])
}
_
{
int
(
bbox
[
2
])
}
_
{
int
(
bbox
[
3
])
}
"
filename
=
f
'
{
page_num
}
_
{
int
(
bbox
[
0
])
}
_
{
int
(
bbox
[
1
])
}
_
{
int
(
bbox
[
2
])
}
_
{
int
(
bbox
[
3
])
}
'
# 老版本返回不带bucket的路径
img_path
=
join_path
(
return_path
,
filename
)
if
return_path
is
not
None
else
None
# 新版本生成平铺路径
img_hash256_path
=
f
"
{
compute_sha256
(
img_path
)
}
.jpg
"
img_hash256_path
=
f
'
{
compute_sha256
(
img_path
)
}
.jpg
'
# 将坐标转换为fitz.Rect对象
rect
=
fitz
.
Rect
(
*
bbox
)
...
...
@@ -28,6 +26,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
byte_data
=
pix
.
tobytes
(
output
=
'jpeg'
,
jpg_quality
=
95
)
imageWriter
.
write
(
byte_data
,
img_hash256_path
,
AbsReaderWriter
.
MODE_BIN
)
imageWriter
.
write
(
img_hash256_path
,
byte_data
)
return
img_hash256_path
magic_pdf/model/magic_model.py
View file @
ea2f8ea0
import
enum
import
json
from
magic_pdf.config.model_block_type
import
ModelBlockTypeEnum
from
magic_pdf.config.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.data.data_reader_writer
import
(
FileBasedDataReader
,
FileBasedDataWriter
)
from
magic_pdf.data.dataset
import
Dataset
from
magic_pdf.libs.boxbase
import
(
_is_in
,
_is_part_overlap
,
bbox_distance
,
bbox_relative_pos
,
box_area
,
calculate_iou
,
...
...
@@ -9,11 +13,7 @@ from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
from
magic_pdf.libs.commons
import
fitz
,
join_path
from
magic_pdf.libs.coordinate_transform
import
get_scale_ratio
from
magic_pdf.libs.local_math
import
float_gt
from
magic_pdf.libs.ModelBlockTypeEnum
import
ModelBlockTypeEnum
from
magic_pdf.libs.ocr_content_type
import
CategoryId
,
ContentType
from
magic_pdf.pre_proc.remove_bbox_overlap
import
_remove_overlap_between_bbox
from
magic_pdf.rw.AbsReaderWriter
import
AbsReaderWriter
from
magic_pdf.rw.DiskReaderWriter
import
DiskReaderWriter
CAPATION_OVERLAP_AREA_RATIO
=
0.6
MERGE_BOX_OVERLAP_AREA_RATIO
=
1.1
...
...
@@ -1050,27 +1050,27 @@ class MagicModel:
if
__name__
==
'__main__'
:
drw
=
DiskReaderWrit
er
(
r
'D:/project/20231108code-clean'
)
drw
=
FileBasedDataRead
er
(
r
'D:/project/20231108code-clean'
)
if
0
:
pdf_file_path
=
r
'linshixuqiu\19983-00.pdf'
model_file_path
=
r
'linshixuqiu\19983-00_new.json'
pdf_bytes
=
drw
.
read
(
pdf_file_path
,
AbsReaderWriter
.
MODE_BIN
)
model_json_txt
=
drw
.
read
(
model_file_path
,
AbsReaderWriter
.
MODE_TXT
)
pdf_bytes
=
drw
.
read
(
pdf_file_path
)
model_json_txt
=
drw
.
read
(
model_file_path
).
decode
(
)
model_list
=
json
.
loads
(
model_json_txt
)
write_path
=
r
'D:\project\20231108code-clean\linshixuqiu\19983-00'
img_bucket_path
=
'imgs'
img_writer
=
DiskReader
Writer
(
join_path
(
write_path
,
img_bucket_path
))
img_writer
=
FileBasedData
Writer
(
join_path
(
write_path
,
img_bucket_path
))
pdf_docs
=
fitz
.
open
(
'pdf'
,
pdf_bytes
)
magic_model
=
MagicModel
(
model_list
,
pdf_docs
)
if
1
:
from
magic_pdf.data.dataset
import
PymuDocDataset
model_list
=
json
.
loads
(
drw
.
read
(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.json'
)
)
pdf_bytes
=
drw
.
read
(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf'
,
AbsReaderWriter
.
MODE_BIN
)
pdf_docs
=
fitz
.
open
(
'pdf'
,
pdf_bytes
)
magic_model
=
MagicModel
(
model_list
,
pdf_docs
)
pdf_bytes
=
drw
.
read
(
'/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf'
)
magic_model
=
MagicModel
(
model_list
,
PymuDocDataset
(
pdf_bytes
))
for
i
in
range
(
7
):
print
(
magic_model
.
get_imgs
(
i
))
magic_pdf/model/pdf_extract_kit.py
View file @
ea2f8ea0
import
numpy
as
np
import
torch
from
loguru
import
logger
# flake8: noqa
import
os
import
time
import
cv2
import
numpy
as
np
import
torch
import
yaml
from
loguru
import
logger
from
PIL
import
Image
os
.
environ
[
'NO_ALBUMENTATIONS_UPDATE'
]
=
'1'
# 禁止albumentations检查更新
...
...
@@ -13,20 +15,21 @@ os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
try
:
import
torchtext
if
torchtext
.
__version__
>=
"
0.18.0
"
:
if
torchtext
.
__version__
>=
'
0.18.0
'
:
torchtext
.
disable_torchtext_deprecation_warning
()
except
ImportError
:
pass
from
magic_pdf.
libs.C
onstants
import
*
from
magic_pdf.
config.c
onstants
import
*
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.sub_modules.model_init
import
AtomModelSingleton
from
magic_pdf.model.sub_modules.model_utils
import
get_res_list_from_layout_res
,
crop_img
,
clean_vram
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
get_adjusted_mfdetrec_res
,
get_ocr_result_list
from
magic_pdf.model.sub_modules.model_utils
import
(
clean_vram
,
crop_img
,
get_res_list_from_layout_res
)
from
magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils
import
(
get_adjusted_mfdetrec_res
,
get_ocr_result_list
)
class
CustomPEKModel
:
def
__init__
(
self
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
**
kwargs
):
"""
======== model init ========
...
...
@@ -41,42 +44,54 @@ class CustomPEKModel:
model_config_dir
=
os
.
path
.
join
(
root_dir
,
'resources'
,
'model_config'
)
# 构建 model_configs.yaml 文件的完整路径
config_path
=
os
.
path
.
join
(
model_config_dir
,
'model_configs.yaml'
)
with
open
(
config_path
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
config_path
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
self
.
configs
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
# 初始化解析配置
# layout config
self
.
layout_config
=
kwargs
.
get
(
"layout_config"
)
self
.
layout_model_name
=
self
.
layout_config
.
get
(
"model"
,
MODEL_NAME
.
DocLayout_YOLO
)
self
.
layout_config
=
kwargs
.
get
(
'layout_config'
)
self
.
layout_model_name
=
self
.
layout_config
.
get
(
'model'
,
MODEL_NAME
.
DocLayout_YOLO
)
# formula config
self
.
formula_config
=
kwargs
.
get
(
"formula_config"
)
self
.
mfd_model_name
=
self
.
formula_config
.
get
(
"mfd_model"
,
MODEL_NAME
.
YOLO_V8_MFD
)
self
.
mfr_model_name
=
self
.
formula_config
.
get
(
"mfr_model"
,
MODEL_NAME
.
UniMerNet_v2_Small
)
self
.
apply_formula
=
self
.
formula_config
.
get
(
"enable"
,
True
)
self
.
formula_config
=
kwargs
.
get
(
'formula_config'
)
self
.
mfd_model_name
=
self
.
formula_config
.
get
(
'mfd_model'
,
MODEL_NAME
.
YOLO_V8_MFD
)
self
.
mfr_model_name
=
self
.
formula_config
.
get
(
'mfr_model'
,
MODEL_NAME
.
UniMerNet_v2_Small
)
self
.
apply_formula
=
self
.
formula_config
.
get
(
'enable'
,
True
)
# table config
self
.
table_config
=
kwargs
.
get
(
"
table_config
"
)
self
.
apply_table
=
self
.
table_config
.
get
(
"
enable
"
,
False
)
self
.
table_max_time
=
self
.
table_config
.
get
(
"
max_time
"
,
TABLE_MAX_TIME_VALUE
)
self
.
table_model_name
=
self
.
table_config
.
get
(
"
model
"
,
MODEL_NAME
.
RAPID_TABLE
)
self
.
table_config
=
kwargs
.
get
(
'
table_config
'
)
self
.
apply_table
=
self
.
table_config
.
get
(
'
enable
'
,
False
)
self
.
table_max_time
=
self
.
table_config
.
get
(
'
max_time
'
,
TABLE_MAX_TIME_VALUE
)
self
.
table_model_name
=
self
.
table_config
.
get
(
'
model
'
,
MODEL_NAME
.
RAPID_TABLE
)
# ocr config
self
.
apply_ocr
=
ocr
self
.
lang
=
kwargs
.
get
(
"
lang
"
,
None
)
self
.
lang
=
kwargs
.
get
(
'
lang
'
,
None
)
logger
.
info
(
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}"
.
format
(
self
.
layout_model_name
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
table_model_name
,
self
.
lang
'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
'apply_table: {}, table_model: {}, lang: {}'
.
format
(
self
.
layout_model_name
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
table_model_name
,
self
.
lang
,
)
)
# 初始化解析方案
self
.
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
logger
.
info
(
"using models_dir: {}"
.
format
(
models_dir
))
self
.
device
=
kwargs
.
get
(
'device'
,
'cpu'
)
logger
.
info
(
'using device: {}'
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
'models_dir'
,
os
.
path
.
join
(
root_dir
,
'resources'
,
'models'
)
)
logger
.
info
(
'using models_dir: {}'
.
format
(
models_dir
))
atom_model_manager
=
AtomModelSingleton
()
...
...
@@ -85,18 +100,24 @@ class CustomPEKModel:
# 初始化公式检测模型
self
.
mfd_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFD
,
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfd_model_name
])),
device
=
self
.
device
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
mfd_model_name
]
)
),
device
=
self
.
device
,
)
# 初始化公式解析模型
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfr_model_name
]))
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
"UniMERNet"
,
"demo.yaml"
))
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
mfr_model_name
])
)
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
'UniMERNet'
,
'demo.yaml'
))
self
.
mfr_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
mfr_cfg_path
=
mfr_cfg_path
,
device
=
self
.
device
device
=
self
.
device
,
)
# 初始化layout模型
...
...
@@ -104,16 +125,28 @@ class CustomPEKModel:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
LAYOUTLMv3
,
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
])),
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
device
=
self
.
device
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
]
)
),
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
'layoutlmv3'
,
'layoutlmv3_base_inference.yaml'
)
),
device
=
self
.
device
,
)
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
DocLayout_YOLO
,
doclayout_yolo_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
])),
device
=
self
.
device
doclayout_yolo_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
]
)
),
device
=
self
.
device
,
)
# 初始化ocr
if
self
.
apply_ocr
:
...
...
@@ -121,23 +154,22 @@ class CustomPEKModel:
atom_model_name
=
AtomicModel
.
OCR
,
ocr_show_log
=
show_log
,
det_db_box_thresh
=
0.3
,
lang
=
self
.
lang
lang
=
self
.
lang
,
)
# init table model
if
self
.
apply_table
:
table_model_dir
=
self
.
configs
[
"
weights
"
][
self
.
table_model_name
]
table_model_dir
=
self
.
configs
[
'
weights
'
][
self
.
table_model_name
]
self
.
table_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Table
,
table_model_name
=
self
.
table_model_name
,
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_max_time
=
self
.
table_max_time
,
device
=
self
.
device
device
=
self
.
device
,
)
logger
.
info
(
'DocAnalysis init done!'
)
def
__call__
(
self
,
image
):
page_start
=
time
.
time
()
# layout检测
...
...
@@ -150,7 +182,7 @@ class CustomPEKModel:
# doclayout_yolo
layout_res
=
self
.
layout_model
.
predict
(
image
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
logger
.
info
(
f
"
layout detection time:
{
layout_cost
}
"
)
logger
.
info
(
f
'
layout detection time:
{
layout_cost
}
'
)
pil_img
=
Image
.
fromarray
(
image
)
...
...
@@ -158,32 +190,40 @@ class CustomPEKModel:
# 公式检测
mfd_start
=
time
.
time
()
mfd_res
=
self
.
mfd_model
.
predict
(
image
)
logger
.
info
(
f
"
mfd time:
{
round
(
time
.
time
()
-
mfd_start
,
2
)
}
"
)
logger
.
info
(
f
'
mfd time:
{
round
(
time
.
time
()
-
mfd_start
,
2
)
}
'
)
# 公式识别
mfr_start
=
time
.
time
()
formula_list
=
self
.
mfr_model
.
predict
(
mfd_res
,
image
)
layout_res
.
extend
(
formula_list
)
mfr_cost
=
round
(
time
.
time
()
-
mfr_start
,
2
)
logger
.
info
(
f
"
formula nums:
{
len
(
formula_list
)
}
, mfr time:
{
mfr_cost
}
"
)
logger
.
info
(
f
'
formula nums:
{
len
(
formula_list
)
}
, mfr time:
{
mfr_cost
}
'
)
# 清理显存
clean_vram
(
self
.
device
,
vram_threshold
=
8
)
# 从layout_res中获取ocr区域、表格区域、公式区域
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
get_res_list_from_layout_res
(
layout_res
)
ocr_res_list
,
table_res_list
,
single_page_mfdetrec_res
=
(
get_res_list_from_layout_res
(
layout_res
)
)
# ocr识别
if
self
.
apply_ocr
:
ocr_start
=
time
.
time
()
# Process each area that requires OCR processing
for
res
in
ocr_res_list
:
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
new_image
,
useful_list
=
crop_img
(
res
,
pil_img
,
crop_paste_x
=
50
,
crop_paste_y
=
50
)
adjusted_mfdetrec_res
=
get_adjusted_mfdetrec_res
(
single_page_mfdetrec_res
,
useful_list
)
# OCR recognition
new_image
=
cv2
.
cvtColor
(
np
.
asarray
(
new_image
),
cv2
.
COLOR_RGB2BGR
)
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
ocr_res
=
self
.
ocr_model
.
ocr
(
new_image
,
mfd_res
=
adjusted_mfdetrec_res
)[
0
]
# Integration results
if
ocr_res
:
...
...
@@ -191,7 +231,7 @@ class CustomPEKModel:
layout_res
.
extend
(
ocr_result_list
)
ocr_cost
=
round
(
time
.
time
()
-
ocr_start
,
2
)
logger
.
info
(
f
"
ocr time:
{
ocr_cost
}
"
)
logger
.
info
(
f
'
ocr time:
{
ocr_cost
}
'
)
# 表格识别 table recognition
if
self
.
apply_table
:
...
...
@@ -202,27 +242,37 @@ class CustomPEKModel:
html_code
=
None
if
self
.
table_model_name
==
MODEL_NAME
.
STRUCT_EQTABLE
:
with
torch
.
no_grad
():
table_result
=
self
.
table_model
.
predict
(
new_image
,
"
html
"
)
table_result
=
self
.
table_model
.
predict
(
new_image
,
'
html
'
)
if
len
(
table_result
)
>
0
:
html_code
=
table_result
[
0
]
elif
self
.
table_model_name
==
MODEL_NAME
.
TABLE_MASTER
:
html_code
=
self
.
table_model
.
img2html
(
new_image
)
elif
self
.
table_model_name
==
MODEL_NAME
.
RAPID_TABLE
:
html_code
,
table_cell_bboxes
,
elapse
=
self
.
table_model
.
predict
(
new_image
)
html_code
,
table_cell_bboxes
,
elapse
=
self
.
table_model
.
predict
(
new_image
)
run_time
=
time
.
time
()
-
single_table_start_time
if
run_time
>
self
.
table_max_time
:
logger
.
warning
(
f
"table recognition processing exceeds max time
{
self
.
table_max_time
}
s"
)
logger
.
warning
(
f
'table recognition processing exceeds max time
{
self
.
table_max_time
}
s'
)
# 判断是否返回正常
if
html_code
:
expected_ending
=
html_code
.
strip
().
endswith
(
'</html>'
)
or
html_code
.
strip
().
endswith
(
'</table>'
)
expected_ending
=
html_code
.
strip
().
endswith
(
'</html>'
)
or
html_code
.
strip
().
endswith
(
'</table>'
)
if
expected_ending
:
res
[
"
html
"
]
=
html_code
res
[
'
html
'
]
=
html_code
else
:
logger
.
warning
(
f
"table recognition processing fails, not found expected HTML table end"
)
logger
.
warning
(
'table recognition processing fails, not found expected HTML table end'
)
else
:
logger
.
warning
(
f
"table recognition processing fails, not get html return"
)
logger
.
info
(
f
"table time:
{
round
(
time
.
time
()
-
table_start
,
2
)
}
"
)
logger
.
warning
(
'table recognition processing fails, not get html return'
)
logger
.
info
(
f
'table time:
{
round
(
time
.
time
()
-
table_start
,
2
)
}
'
)
logger
.
info
(
f
"
-----page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----
"
)
logger
.
info
(
f
'
-----page total time:
{
round
(
time
.
time
()
-
page_start
,
2
)
}
-----
'
)
return
layout_res
magic_pdf/model/sub_modules/model_init.py
View file @
ea2f8ea0
from
loguru
import
logger
from
magic_pdf.
libs.C
onstants
import
MODEL_NAME
from
magic_pdf.
config.c
onstants
import
MODEL_NAME
from
magic_pdf.model.model_list
import
AtomicModel
from
magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO
import
DocLayoutYOLOModel
from
magic_pdf.model.sub_modules.layout.layoutlmv3.model_init
import
Layoutlmv3_Predictor
from
magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO
import
\
DocLayoutYOLOModel
from
magic_pdf.model.sub_modules.layout.layoutlmv3.model_init
import
\
Layoutlmv3_Predictor
from
magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8
import
YOLOv8MFDModel
from
magic_pdf.model.sub_modules.mfr.unimernet.Unimernet
import
UnimernetModel
from
magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod
import
ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod
import
\
ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.table.rapidtable.rapid_table
import
\
RapidTableModel
# from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
from
magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable
import
StructTableModel
from
magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle
import
TableMasterPaddleModel
from
magic_pdf.model.sub_modules.table.rapidtable.rapid_table
import
RapidTableModel
from
magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable
import
\
StructTableModel
from
magic_pdf.model.sub_modules.table.tablemaster.tablemaster_paddle
import
\
TableMasterPaddleModel
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
):
...
...
@@ -19,14 +24,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
table_model
=
StructTableModel
(
model_path
,
max_new_tokens
=
2048
,
max_time
=
max_time
)
elif
table_model_type
==
MODEL_NAME
.
TABLE_MASTER
:
config
=
{
"
model_dir
"
:
model_path
,
"
device
"
:
_device_
'
model_dir
'
:
model_path
,
'
device
'
:
_device_
}
table_model
=
TableMasterPaddleModel
(
config
)
elif
table_model_type
==
MODEL_NAME
.
RAPID_TABLE
:
table_model
=
RapidTableModel
()
else
:
logger
.
error
(
"
table model type not allow
"
)
logger
.
error
(
'
table model type not allow
'
)
exit
(
1
)
return
table_model
...
...
@@ -87,8 +92,8 @@ class AtomModelSingleton:
return
cls
.
_instance
def
get_atom_model
(
self
,
atom_model_name
:
str
,
**
kwargs
):
lang
=
kwargs
.
get
(
"
lang
"
,
None
)
layout_model_name
=
kwargs
.
get
(
"
layout_model_name
"
,
None
)
lang
=
kwargs
.
get
(
'
lang
'
,
None
)
layout_model_name
=
kwargs
.
get
(
'
layout_model_name
'
,
None
)
key
=
(
atom_model_name
,
layout_model_name
,
lang
)
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
...
...
@@ -98,47 +103,47 @@ class AtomModelSingleton:
def
atom_model_init
(
model_name
:
str
,
**
kwargs
):
atom_model
=
None
if
model_name
==
AtomicModel
.
Layout
:
if
kwargs
.
get
(
"
layout_model_name
"
)
==
MODEL_NAME
.
LAYOUTLMv3
:
if
kwargs
.
get
(
'
layout_model_name
'
)
==
MODEL_NAME
.
LAYOUTLMv3
:
atom_model
=
layout_model_init
(
kwargs
.
get
(
"
layout_weights
"
),
kwargs
.
get
(
"
layout_config_file
"
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
layout_weights
'
),
kwargs
.
get
(
'
layout_config_file
'
),
kwargs
.
get
(
'
device
'
)
)
elif
kwargs
.
get
(
"
layout_model_name
"
)
==
MODEL_NAME
.
DocLayout_YOLO
:
elif
kwargs
.
get
(
'
layout_model_name
'
)
==
MODEL_NAME
.
DocLayout_YOLO
:
atom_model
=
doclayout_yolo_model_init
(
kwargs
.
get
(
"
doclayout_yolo_weights
"
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
doclayout_yolo_weights
'
),
kwargs
.
get
(
'
device
'
)
)
elif
model_name
==
AtomicModel
.
MFD
:
atom_model
=
mfd_model_init
(
kwargs
.
get
(
"
mfd_weights
"
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
mfd_weights
'
),
kwargs
.
get
(
'
device
'
)
)
elif
model_name
==
AtomicModel
.
MFR
:
atom_model
=
mfr_model_init
(
kwargs
.
get
(
"
mfr_weight_dir
"
),
kwargs
.
get
(
"
mfr_cfg_path
"
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
mfr_weight_dir
'
),
kwargs
.
get
(
'
mfr_cfg_path
'
),
kwargs
.
get
(
'
device
'
)
)
elif
model_name
==
AtomicModel
.
OCR
:
atom_model
=
ocr_model_init
(
kwargs
.
get
(
"
ocr_show_log
"
),
kwargs
.
get
(
"
det_db_box_thresh
"
),
kwargs
.
get
(
"
lang
"
)
kwargs
.
get
(
'
ocr_show_log
'
),
kwargs
.
get
(
'
det_db_box_thresh
'
),
kwargs
.
get
(
'
lang
'
)
)
elif
model_name
==
AtomicModel
.
Table
:
atom_model
=
table_model_init
(
kwargs
.
get
(
"
table_model_name
"
),
kwargs
.
get
(
"
table_model_path
"
),
kwargs
.
get
(
"
table_max_time
"
),
kwargs
.
get
(
"
device
"
)
kwargs
.
get
(
'
table_model_name
'
),
kwargs
.
get
(
'
table_model_path
'
),
kwargs
.
get
(
'
table_max_time
'
),
kwargs
.
get
(
'
device
'
)
)
else
:
logger
.
error
(
"
model name not allow
"
)
logger
.
error
(
'
model name not allow
'
)
exit
(
1
)
if
atom_model
is
None
:
logger
.
error
(
"
model init failed
"
)
logger
.
error
(
'
model init failed
'
)
exit
(
1
)
else
:
return
atom_model
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment