Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
MinerU
Commits
c1ba9dcb
Unverified
Commit
c1ba9dcb
authored
Oct 23, 2024
by
Xiaomeng Zhao
Committed by
GitHub
Oct 23, 2024
Browse files
Merge pull request #773 from myhloli/add-doclayout-yolo
feat(model): add support for DocLayout-YOLO model
parents
efb5851f
1279f2cd
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
365 additions
and
130 deletions
+365
-130
magic-pdf.template.json
magic-pdf.template.json
+1
-1
magic_pdf/libs/Constants.py
magic_pdf/libs/Constants.py
+13
-6
magic_pdf/libs/boxbase.py
magic_pdf/libs/boxbase.py
+35
-0
magic_pdf/libs/config_reader.py
magic_pdf/libs/config_reader.py
+22
-1
magic_pdf/model/doc_analyze_by_custom_model.py
magic_pdf/model/doc_analyze_by_custom_model.py
+38
-14
magic_pdf/model/pdf_extract_kit.py
magic_pdf/model/pdf_extract_kit.py
+82
-43
magic_pdf/model/ppTableModel.py
magic_pdf/model/ppTableModel.py
+2
-2
magic_pdf/pipe/AbsPipe.py
magic_pdf/pipe/AbsPipe.py
+4
-1
magic_pdf/pipe/OCRPipe.py
magic_pdf/pipe/OCRPipe.py
+8
-4
magic_pdf/pipe/TXTPipe.py
magic_pdf/pipe/TXTPipe.py
+8
-4
magic_pdf/pipe/UNIPipe.py
magic_pdf/pipe/UNIPipe.py
+10
-5
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
+25
-3
magic_pdf/resources/model_config/model_configs.yaml
magic_pdf/resources/model_config/model_configs.yaml
+5
-13
magic_pdf/tools/common.py
magic_pdf/tools/common.py
+9
-4
magic_pdf/user_api.py
magic_pdf/user_api.py
+13
-5
old_docs/download_models.py
old_docs/download_models.py
+21
-8
old_docs/download_models_hf.py
old_docs/download_models_hf.py
+29
-9
projects/gradio_app/app.py
projects/gradio_app/app.py
+40
-7
No files found.
magic-pdf.template.json
View file @
c1ba9dcb
...
...
@@ -7,7 +7,7 @@
"layoutreader-model-dir"
:
"/tmp/layoutreader"
,
"device-mode"
:
"cpu"
,
"layout-config"
:
{
"model"
:
"
doc
layout
_yolo
"
"model"
:
"layout
lmv3
"
},
"formula-config"
:
{
"mfd_model"
:
"yolo_v8_mfd"
,
...
...
magic_pdf/libs/Constants.py
View file @
c1ba9dcb
...
...
@@ -10,18 +10,12 @@ block维度自定义字段
# block中lines是否被删除
LINES_DELETED
=
"lines_deleted"
# struct eqtable
STRUCT_EQTABLE
=
"struct_eqtable"
# table recognition max time default value
TABLE_MAX_TIME_VALUE
=
400
# pp_table_result_max_length
TABLE_MAX_LEN
=
480
# pp table structure algorithm
TABLE_MASTER
=
"TableMaster"
# table master structure dict
TABLE_MASTER_DICT
=
"table_master_structure_dict.txt"
...
...
@@ -38,3 +32,16 @@ REC_MODEL_DIR = "ch_PP-OCRv3_rec_infer"
REC_CHAR_DICT
=
"ppocr_keys_v1.txt"
class
MODEL_NAME
:
# pp table structure algorithm
TABLE_MASTER
=
"tablemaster"
# struct eqtable
STRUCT_EQTABLE
=
"struct_eqtable"
DocLayout_YOLO
=
"doclayout_yolo"
LAYOUTLMv3
=
"layoutlmv3"
YOLO_V8_MFD
=
"yolo_v8_mfd"
UniMerNet_v2_Small
=
"unimernet_small"
\ No newline at end of file
magic_pdf/libs/boxbase.py
View file @
c1ba9dcb
...
...
@@ -445,3 +445,38 @@ def get_overlap_area(bbox1, bbox2):
# The area of overlap area
return
(
x_right
-
x_left
)
*
(
y_bottom
-
y_top
)
def
calculate_vertical_projection_overlap_ratio
(
block1
,
block2
):
"""
Calculate the proportion of the x-axis covered by the vertical projection of two blocks.
Args:
block1 (tuple): Coordinates of the first block (x0, y0, x1, y1).
block2 (tuple): Coordinates of the second block (x0, y0, x1, y1).
Returns:
float: The proportion of the x-axis covered by the vertical projection of the two blocks.
"""
x0_1
,
_
,
x1_1
,
_
=
block1
x0_2
,
_
,
x1_2
,
_
=
block2
# Calculate the intersection of the x-coordinates
x_left
=
max
(
x0_1
,
x0_2
)
x_right
=
min
(
x1_1
,
x1_2
)
if
x_right
<
x_left
:
return
0.0
# Length of the intersection
intersection_length
=
x_right
-
x_left
# Length of the x-axis projection of the first block
block1_length
=
x1_1
-
x0_1
if
block1_length
==
0
:
return
0.0
# Proportion of the x-axis covered by the intersection
# logger.info(f"intersection_length: {intersection_length}, block1_length: {block1_length}")
return
intersection_length
/
block1_length
magic_pdf/libs/config_reader.py
View file @
c1ba9dcb
...
...
@@ -8,6 +8,7 @@ import os
from
loguru
import
logger
from
magic_pdf.libs.Constants
import
MODEL_NAME
from
magic_pdf.libs.commons
import
parse_bucket_key
# 定义配置文件名常量
...
...
@@ -94,10 +95,30 @@ def get_table_recog_config():
table_config
=
config
.
get
(
"table-config"
)
if
table_config
is
None
:
logger
.
warning
(
f
"'table-config' not found in
{
CONFIG_FILE_NAME
}
, use 'False' as default"
)
return
json
.
loads
(
'{"is_table_recog_
enable": false, "max_time": 400}'
)
return
json
.
loads
(
f
'{{"model": "
{
MODEL_NAME
.
TABLE_MASTER
}
","
enable": false, "max_time": 400}
}
'
)
else
:
return
table_config
def
get_layout_config
():
config
=
read_config
()
layout_config
=
config
.
get
(
"layout-config"
)
if
layout_config
is
None
:
logger
.
warning
(
f
"'layout-config' not found in
{
CONFIG_FILE_NAME
}
, use '
{
MODEL_NAME
.
LAYOUTLMv3
}
' as default"
)
return
json
.
loads
(
f
'{{"model": "
{
MODEL_NAME
.
LAYOUTLMv3
}
"}}'
)
else
:
return
layout_config
def
get_formula_config
():
config
=
read_config
()
formula_config
=
config
.
get
(
"formula-config"
)
if
formula_config
is
None
:
logger
.
warning
(
f
"'formula-config' not found in
{
CONFIG_FILE_NAME
}
, use 'True' as default"
)
return
json
.
loads
(
f
'{{"mfd_model": "
{
MODEL_NAME
.
YOLO_V8_MFD
}
","mfr_model": "
{
MODEL_NAME
.
UniMerNet_v2_Small
}
","enable": true}}'
)
else
:
return
formula_config
if
__name__
==
"__main__"
:
ak
,
sk
,
endpoint
=
get_s3_config
(
"llm-raw"
)
magic_pdf/model/doc_analyze_by_custom_model.py
View file @
c1ba9dcb
...
...
@@ -5,7 +5,8 @@ import numpy as np
from
loguru
import
logger
from
magic_pdf.libs.clean_memory
import
clean_memory
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
,
get_table_recog_config
from
magic_pdf.libs.config_reader
import
get_local_models_dir
,
get_device
,
get_table_recog_config
,
get_layout_config
,
\
get_formula_config
from
magic_pdf.model.model_list
import
MODEL
import
magic_pdf.model
as
model_config
...
...
@@ -68,14 +69,17 @@ class ModelSingleton:
cls
.
_instance
=
super
().
__new__
(
cls
)
return
cls
.
_instance
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
,
lang
=
None
):
key
=
(
ocr
,
show_log
,
lang
)
def
get_model
(
self
,
ocr
:
bool
,
show_log
:
bool
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
key
=
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
)
self
.
_models
[
key
]
=
custom_model_init
(
ocr
=
ocr
,
show_log
=
show_log
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
return
self
.
_models
[
key
]
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
):
def
custom_model_init
(
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
model
=
None
if
model_config
.
__model_mode__
==
"lite"
:
...
...
@@ -95,14 +99,30 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
# 从配置文件读取model-dir和device
local_models_dir
=
get_local_models_dir
()
device
=
get_device
()
layout_config
=
get_layout_config
()
if
layout_model
is
not
None
:
layout_config
[
"model"
]
=
layout_model
formula_config
=
get_formula_config
()
if
formula_enable
is
not
None
:
formula_config
[
"enable"
]
=
formula_enable
table_config
=
get_table_recog_config
()
model_input
=
{
"ocr"
:
ocr
,
"show_log"
:
show_log
,
"models_dir"
:
local_models_dir
,
"device"
:
device
,
"table_config"
:
table_config
,
"lang"
:
lang
,
}
if
table_enable
is
not
None
:
table_config
[
"enable"
]
=
table_enable
model_input
=
{
"ocr"
:
ocr
,
"show_log"
:
show_log
,
"models_dir"
:
local_models_dir
,
"device"
:
device
,
"table_config"
:
table_config
,
"layout_config"
:
layout_config
,
"formula_config"
:
formula_config
,
"lang"
:
lang
,
}
custom_model
=
CustomPEKModel
(
**
model_input
)
else
:
logger
.
error
(
"Not allow model_name!"
)
...
...
@@ -117,10 +137,14 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
def
doc_analyze
(
pdf_bytes
:
bytes
,
ocr
:
bool
=
False
,
show_log
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
if
lang
==
""
:
lang
=
None
model_manager
=
ModelSingleton
()
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
)
custom_model
=
model_manager
.
get_model
(
ocr
,
show_log
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
with
fitz
.
open
(
"pdf"
,
pdf_bytes
)
as
doc
:
pdf_page_num
=
doc
.
page_count
...
...
magic_pdf/model/pdf_extract_kit.py
View file @
c1ba9dcb
...
...
@@ -25,6 +25,7 @@ try:
from
unimernet.common.config
import
Config
import
unimernet.tasks
as
tasks
from
unimernet.processors
import
load_processor
from
doclayout_yolo
import
YOLOv10
except
ImportError
as
e
:
logger
.
exception
(
e
)
...
...
@@ -41,7 +42,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
def
table_model_init
(
table_model_type
,
model_path
,
max_time
,
_device_
=
'cpu'
):
if
table_model_type
==
STRUCT_EQTABLE
:
if
table_model_type
==
MODEL_NAME
.
STRUCT_EQTABLE
:
table_model
=
StructTableModel
(
model_path
,
max_time
=
max_time
,
device
=
_device_
)
else
:
config
=
{
...
...
@@ -77,6 +78,11 @@ def layout_model_init(weight, config_file, device):
return
model
def
doclayout_yolo_model_init
(
weight
):
model
=
YOLOv10
(
weight
)
return
model
def
ocr_model_init
(
show_log
:
bool
=
False
,
det_db_box_thresh
=
0.3
,
lang
=
None
,
use_dilation
=
True
,
det_db_unclip_ratio
=
2.4
):
if
lang
is
not
None
:
model
=
ModifiedPaddleOCR
(
show_log
=
show_log
,
det_db_box_thresh
=
det_db_box_thresh
,
lang
=
lang
,
use_dilation
=
use_dilation
,
det_db_unclip_ratio
=
det_db_unclip_ratio
)
...
...
@@ -114,19 +120,27 @@ class AtomModelSingleton:
return
cls
.
_instance
def
get_atom_model
(
self
,
atom_model_name
:
str
,
**
kwargs
):
if
atom_model_name
not
in
self
.
_models
:
self
.
_models
[
atom_model_name
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
return
self
.
_models
[
atom_model_name
]
lang
=
kwargs
.
get
(
"lang"
,
None
)
layout_model_name
=
kwargs
.
get
(
"layout_model_name"
,
None
)
key
=
(
atom_model_name
,
layout_model_name
,
lang
)
if
key
not
in
self
.
_models
:
self
.
_models
[
key
]
=
atom_model_init
(
model_name
=
atom_model_name
,
**
kwargs
)
return
self
.
_models
[
key
]
def
atom_model_init
(
model_name
:
str
,
**
kwargs
):
if
model_name
==
AtomicModel
.
Layout
:
atom_model
=
layout_model_init
(
kwargs
.
get
(
"layout_weights"
),
kwargs
.
get
(
"layout_config_file"
),
kwargs
.
get
(
"device"
)
)
if
kwargs
.
get
(
"layout_model_name"
)
==
MODEL_NAME
.
LAYOUTLMv3
:
atom_model
=
layout_model_init
(
kwargs
.
get
(
"layout_weights"
),
kwargs
.
get
(
"layout_config_file"
),
kwargs
.
get
(
"device"
)
)
elif
kwargs
.
get
(
"layout_model_name"
)
==
MODEL_NAME
.
DocLayout_YOLO
:
atom_model
=
doclayout_yolo_model_init
(
kwargs
.
get
(
"doclayout_yolo_weights"
),
)
elif
model_name
==
AtomicModel
.
MFD
:
atom_model
=
mfd_model_init
(
kwargs
.
get
(
"mfd_weights"
)
...
...
@@ -145,7 +159,7 @@ def atom_model_init(model_name: str, **kwargs):
)
elif
model_name
==
AtomicModel
.
Table
:
atom_model
=
table_model_init
(
kwargs
.
get
(
"table_model_
typ
e"
),
kwargs
.
get
(
"table_model_
nam
e"
),
kwargs
.
get
(
"table_model_path"
),
kwargs
.
get
(
"table_max_time"
),
kwargs
.
get
(
"device"
)
...
...
@@ -193,23 +207,35 @@ class CustomPEKModel:
with
open
(
config_path
,
"r"
,
encoding
=
'utf-8'
)
as
f
:
self
.
configs
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
# 初始化解析配置
self
.
apply_layout
=
kwargs
.
get
(
"apply_layout"
,
self
.
configs
[
"config"
][
"layout"
])
self
.
apply_formula
=
kwargs
.
get
(
"apply_formula"
,
self
.
configs
[
"config"
][
"formula"
])
# layout config
self
.
layout_config
=
kwargs
.
get
(
"layout_config"
)
self
.
layout_model_name
=
self
.
layout_config
.
get
(
"model"
,
MODEL_NAME
.
DocLayout_YOLO
)
# formula config
self
.
formula_config
=
kwargs
.
get
(
"formula_config"
)
self
.
mfd_model_name
=
self
.
formula_config
.
get
(
"mfd_model"
,
MODEL_NAME
.
YOLO_V8_MFD
)
self
.
mfr_model_name
=
self
.
formula_config
.
get
(
"mfr_model"
,
MODEL_NAME
.
UniMerNet_v2_Small
)
self
.
apply_formula
=
self
.
formula_config
.
get
(
"enable"
,
True
)
# table config
self
.
table_config
=
kwargs
.
get
(
"table_config"
,
self
.
configs
[
"config"
][
"table_config"
]
)
self
.
apply_table
=
self
.
table_config
.
get
(
"
is_table_recog_
enable"
,
False
)
self
.
table_config
=
kwargs
.
get
(
"table_config"
)
self
.
apply_table
=
self
.
table_config
.
get
(
"enable"
,
False
)
self
.
table_max_time
=
self
.
table_config
.
get
(
"max_time"
,
TABLE_MAX_TIME_VALUE
)
self
.
table_model_type
=
self
.
table_config
.
get
(
"model"
,
TABLE_MASTER
)
self
.
table_model_name
=
self
.
table_config
.
get
(
"model"
,
MODEL_NAME
.
TABLE_MASTER
)
# ocr config
self
.
apply_ocr
=
ocr
self
.
lang
=
kwargs
.
get
(
"lang"
,
None
)
logger
.
info
(
"DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}"
.
format
(
self
.
apply_layout
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
lang
"DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
"apply_table: {}, table_model: {}, lang: {}"
.
format
(
self
.
layout_model_name
,
self
.
apply_formula
,
self
.
apply_ocr
,
self
.
apply_table
,
self
.
table_model_name
,
self
.
lang
)
)
assert
self
.
apply_layout
,
"DocAnalysis must contain layout model."
# 初始化解析方案
self
.
device
=
kwargs
.
get
(
"device"
,
self
.
configs
[
"config"
][
"device"
]
)
self
.
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
logger
.
info
(
"using device: {}"
.
format
(
self
.
device
))
models_dir
=
kwargs
.
get
(
"models_dir"
,
os
.
path
.
join
(
root_dir
,
"resources"
,
"models"
))
logger
.
info
(
"using models_dir: {}"
.
format
(
models_dir
))
...
...
@@ -218,17 +244,16 @@ class CustomPEKModel:
# 初始化公式识别
if
self
.
apply_formula
:
# 初始化公式检测模型
# self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
self
.
mfd_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFD
,
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfd"
]))
mfd_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfd_model_name
]))
)
# 初始化公式解析模型
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
"mfr"
]))
mfr_weight_dir
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
"weights"
][
self
.
mfr_model_name
]))
mfr_cfg_path
=
str
(
os
.
path
.
join
(
model_config_dir
,
"UniMERNet"
,
"demo.yaml"
))
# self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
# self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
self
.
mfr_model
,
self
.
mfr_transform
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
MFR
,
mfr_weight_dir
=
mfr_weight_dir
,
...
...
@@ -237,17 +262,20 @@ class CustomPEKModel:
)
# 初始化layout模型
# self.layout_model = Layoutlmv3_Predictor(
# str(os.path.join(models_dir, self.configs['weights']['layout'])),
# str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
# device=self.device
# )
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
'layout'
])),
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
device
=
self
.
device
)
if
self
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
LAYOUTLMv3
,
layout_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
])),
layout_config_file
=
str
(
os
.
path
.
join
(
model_config_dir
,
"layoutlmv3"
,
"layoutlmv3_base_inference.yaml"
)),
device
=
self
.
device
)
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
self
.
layout_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Layout
,
layout_model_name
=
MODEL_NAME
.
DocLayout_YOLO
,
doclayout_yolo_weights
=
str
(
os
.
path
.
join
(
models_dir
,
self
.
configs
[
'weights'
][
self
.
layout_model_name
]))
)
# 初始化ocr
if
self
.
apply_ocr
:
...
...
@@ -260,12 +288,10 @@ class CustomPEKModel:
)
# init table model
if
self
.
apply_table
:
table_model_dir
=
self
.
configs
[
"weights"
][
self
.
table_model_type
]
# self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
# max_time=self.table_max_time, _device_=self.device)
table_model_dir
=
self
.
configs
[
"weights"
][
self
.
table_model_name
]
self
.
table_model
=
atom_model_manager
.
get_atom_model
(
atom_model_name
=
AtomicModel
.
Table
,
table_model_
typ
e
=
self
.
table_model_
typ
e
,
table_model_
nam
e
=
self
.
table_model_
nam
e
,
table_model_path
=
str
(
os
.
path
.
join
(
models_dir
,
table_model_dir
)),
table_max_time
=
self
.
table_max_time
,
device
=
self
.
device
...
...
@@ -282,7 +308,21 @@ class CustomPEKModel:
# layout检测
layout_start
=
time
.
time
()
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
if
self
.
layout_model_name
==
MODEL_NAME
.
LAYOUTLMv3
:
# layoutlmv3
layout_res
=
self
.
layout_model
(
image
,
ignore_catids
=
[])
elif
self
.
layout_model_name
==
MODEL_NAME
.
DocLayout_YOLO
:
# doclayout_yolo
layout_res
=
[]
doclayout_yolo_res
=
self
.
layout_model
.
predict
(
image
,
imgsz
=
1024
,
conf
=
0.15
,
iou
=
0.45
,
verbose
=
True
,
device
=
self
.
device
)[
0
]
for
xyxy
,
conf
,
cla
in
zip
(
doclayout_yolo_res
.
boxes
.
xyxy
.
cpu
(),
doclayout_yolo_res
.
boxes
.
conf
.
cpu
(),
doclayout_yolo_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
new_item
=
{
'category_id'
:
int
(
cla
.
item
()),
'poly'
:
[
xmin
,
ymin
,
xmax
,
ymin
,
xmax
,
ymax
,
xmin
,
ymax
],
'score'
:
round
(
float
(
conf
.
item
()),
3
),
}
layout_res
.
append
(
new_item
)
layout_cost
=
round
(
time
.
time
()
-
layout_start
,
2
)
logger
.
info
(
f
"layout detection time:
{
layout_cost
}
"
)
...
...
@@ -291,7 +331,7 @@ class CustomPEKModel:
if
self
.
apply_formula
:
# 公式检测
mfd_start
=
time
.
time
()
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
)[
0
]
mfd_res
=
self
.
mfd_model
.
predict
(
image
,
imgsz
=
1888
,
conf
=
0.25
,
iou
=
0.45
,
verbose
=
True
,
device
=
self
.
device
)[
0
]
logger
.
info
(
f
"mfd time:
{
round
(
time
.
time
()
-
mfd_start
,
2
)
}
"
)
for
xyxy
,
conf
,
cla
in
zip
(
mfd_res
.
boxes
.
xyxy
.
cpu
(),
mfd_res
.
boxes
.
conf
.
cpu
(),
mfd_res
.
boxes
.
cls
.
cpu
()):
xmin
,
ymin
,
xmax
,
ymax
=
[
int
(
p
.
item
())
for
p
in
xyxy
]
...
...
@@ -303,7 +343,6 @@ class CustomPEKModel:
}
layout_res
.
append
(
new_item
)
latex_filling_list
.
append
(
new_item
)
# bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
bbox_img
=
pil_img
.
crop
((
xmin
,
ymin
,
xmax
,
ymax
))
mf_image_list
.
append
(
bbox_img
)
...
...
@@ -405,7 +444,7 @@ class CustomPEKModel:
# logger.info("------------------table recognition processing begins-----------------")
latex_code
=
None
html_code
=
None
if
self
.
table_model_
typ
e
==
STRUCT_EQTABLE
:
if
self
.
table_model_
nam
e
==
MODEL_NAME
.
STRUCT_EQTABLE
:
with
torch
.
no_grad
():
latex_code
=
self
.
table_model
.
image2latex
(
new_image
)[
0
]
else
:
...
...
magic_pdf/model/ppTableModel.py
View file @
c1ba9dcb
...
...
@@ -52,11 +52,11 @@ class ppTableModel(object):
rec_model_dir
=
os
.
path
.
join
(
model_dir
,
REC_MODEL_DIR
)
rec_char_dict_path
=
os
.
path
.
join
(
model_dir
,
REC_CHAR_DICT
)
device
=
kwargs
.
get
(
"device"
,
"cpu"
)
use_gpu
=
True
if
device
==
"cuda"
else
False
use_gpu
=
True
if
device
.
startswith
(
"cuda"
)
else
False
config
=
{
"use_gpu"
:
use_gpu
,
"table_max_len"
:
kwargs
.
get
(
"table_max_len"
,
TABLE_MAX_LEN
),
"table_algorithm"
:
TABLE_MASTER
,
"table_algorithm"
:
"TableMaster"
,
"table_model_dir"
:
table_model_dir
,
"table_char_dict_path"
:
table_char_dict_path
,
"det_model_dir"
:
det_model_dir
,
...
...
magic_pdf/pipe/AbsPipe.py
View file @
c1ba9dcb
...
...
@@ -17,7 +17,7 @@ class AbsPipe(ABC):
PIP_TXT
=
"txt"
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_bytes
=
pdf_bytes
self
.
model_list
=
model_list
self
.
image_writer
=
image_writer
...
...
@@ -26,6 +26,9 @@ class AbsPipe(ABC):
self
.
start_page_id
=
start_page_id
self
.
end_page_id
=
end_page_id
self
.
lang
=
lang
self
.
layout_model
=
layout_model
self
.
formula_enable
=
formula_enable
self
.
table_enable
=
table_enable
def
get_compress_pdf_mid_data
(
self
):
return
JsonCompressor
.
compress_json
(
self
.
pdf_mid_data
)
...
...
magic_pdf/pipe/OCRPipe.py
View file @
c1ba9dcb
...
...
@@ -10,8 +10,10 @@ from magic_pdf.user_api import parse_ocr_pdf
class
OCRPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
pass
...
...
@@ -19,12 +21,14 @@ class OCRPipe(AbsPipe):
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/TXTPipe.py
View file @
c1ba9dcb
...
...
@@ -11,8 +11,10 @@ from magic_pdf.user_api import parse_txt_pdf
class
TXTPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
model_list
:
list
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
super
().
__init__
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
def
pipe_classify
(
self
):
pass
...
...
@@ -20,12 +22,14 @@ class TXTPipe(AbsPipe):
def
pipe_analyze
(
self
):
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
self
.
pdf_mid_data
=
parse_txt_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_mk_uni_format
(
self
,
img_parent_path
:
str
,
drop_mode
=
DropMode
.
WHOLE_PDF
):
result
=
super
().
pipe_mk_uni_format
(
img_parent_path
,
drop_mode
)
...
...
magic_pdf/pipe/UNIPipe.py
View file @
c1ba9dcb
...
...
@@ -14,9 +14,11 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
class
UNIPipe
(
AbsPipe
):
def
__init__
(
self
,
pdf_bytes
:
bytes
,
jso_useful_key
:
dict
,
image_writer
:
AbsReaderWriter
,
is_debug
:
bool
=
False
,
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
):
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
):
self
.
pdf_type
=
jso_useful_key
[
"_pdf_type"
]
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
)
super
().
__init__
(
pdf_bytes
,
jso_useful_key
[
"model_list"
],
image_writer
,
is_debug
,
start_page_id
,
end_page_id
,
lang
,
layout_model
,
formula_enable
,
table_enable
)
if
len
(
self
.
model_list
)
==
0
:
self
.
input_model_is_empty
=
True
else
:
...
...
@@ -29,18 +31,21 @@ class UNIPipe(AbsPipe):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
False
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
model_list
=
doc_analyze
(
self
.
pdf_bytes
,
ocr
=
True
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
def
pipe_parse
(
self
):
if
self
.
pdf_type
==
self
.
PIP_TXT
:
self
.
pdf_mid_data
=
parse_union_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
input_model_is_empty
=
self
.
input_model_is_empty
,
start_page_id
=
self
.
start_page_id
,
end_page_id
=
self
.
end_page_id
,
lang
=
self
.
lang
)
lang
=
self
.
lang
,
layout_model
=
self
.
layout_model
,
formula_enable
=
self
.
formula_enable
,
table_enable
=
self
.
table_enable
)
elif
self
.
pdf_type
==
self
.
PIP_OCR
:
self
.
pdf_mid_data
=
parse_ocr_pdf
(
self
.
pdf_bytes
,
self
.
model_list
,
self
.
image_writer
,
is_debug
=
self
.
is_debug
,
...
...
magic_pdf/pre_proc/ocr_detect_all_bboxes.py
View file @
c1ba9dcb
from
loguru
import
logger
from
magic_pdf.libs.boxbase
import
get_minbox_if_overlap_by_ratio
,
calculate_overlap_area_in_bbox1_area_ratio
,
\
calculate_iou
calculate_iou
,
calculate_vertical_projection_overlap_ratio
from
magic_pdf.libs.drop_tag
import
DropTag
from
magic_pdf.libs.ocr_content_type
import
BlockType
from
magic_pdf.pre_proc.remove_bbox_overlap
import
remove_overlap_between_bbox_for_block
...
...
@@ -97,12 +97,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
# 通过后续大框套小框逻辑删除
'''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
footnote_blocks
=
[]
for
discarded
in
discarded_blocks
:
x0
,
y0
,
x1
,
y1
=
discarded
[
'bbox'
]
all_discarded_blocks
.
append
([
x0
,
y0
,
x1
,
y1
,
None
,
None
,
None
,
BlockType
.
Discarded
,
None
,
None
,
None
,
None
,
discarded
[
"score"
]])
# 将footnote加入到all_bboxes中,用来计算layout
# if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
# all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Footnote, None, None, None, None, discarded["score"]])
if
(
x1
-
x0
)
>
(
page_w
/
3
)
and
(
y1
-
y0
)
>
10
and
y0
>
(
page_h
/
2
):
footnote_blocks
.
append
([
x0
,
y0
,
x1
,
y1
])
'''移除在footnote下面的任何框'''
need_remove_blocks
=
find_blocks_under_footnote
(
all_bboxes
,
footnote_blocks
)
if
len
(
need_remove_blocks
)
>
0
:
for
block
in
need_remove_blocks
:
all_bboxes
.
remove
(
block
)
all_discarded_blocks
.
append
(
block
)
'''经过以上处理后,还存在大框套小框的情况,则删除小框'''
all_bboxes
=
remove_overlaps_min_blocks
(
all_bboxes
)
...
...
@@ -113,6 +121,20 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
return
all_bboxes
,
all_discarded_blocks
def
find_blocks_under_footnote
(
all_bboxes
,
footnote_blocks
):
need_remove_blocks
=
[]
for
block
in
all_bboxes
:
block_x0
,
block_y0
,
block_x1
,
block_y1
=
block
[:
4
]
for
footnote_bbox
in
footnote_blocks
:
footnote_x0
,
footnote_y0
,
footnote_x1
,
footnote_y1
=
footnote_bbox
# 如果footnote的纵向投影覆盖了block的纵向投影的80%且block的y0大于等于footnote的y1
if
block_y0
>=
footnote_y1
and
calculate_vertical_projection_overlap_ratio
((
block_x0
,
block_y0
,
block_x1
,
block_y1
),
footnote_bbox
)
>=
0.8
:
if
block
not
in
need_remove_blocks
:
need_remove_blocks
.
append
(
block
)
break
return
need_remove_blocks
def
fix_interline_equation_overlap_text_blocks_with_hi_iou
(
all_bboxes
):
# 先提取所有text和interline block
text_blocks
=
[]
...
...
magic_pdf/resources/model_config/model_configs.yaml
View file @
c1ba9dcb
config
:
device
:
cpu
layout
:
True
formula
:
True
table_config
:
model
:
TableMaster
is_table_recog_enable
:
False
max_time
:
400
weights
:
layout
:
Layout/model_final.pth
mfd
:
MFD/weights.pt
mfr
:
MFR/unimernet_small
layoutlmv3
:
Layout/LayoutLMv3/model_final.pth
doclayout_yolo
:
Layout/YOLO/doclayout_yolo_ft.pt
yolo_v8_mfd
:
MFD/YOLO/yolo_v8_ft.pt
unimernet_small
:
MFR/unimernet_small
struct_eqtable
:
TabRec/StructEqTable
TableMaster
:
TabRec/TableMaster
\ No newline at end of file
tablemaster
:
TabRec/TableMaster
\ No newline at end of file
magic_pdf/tools/common.py
View file @
c1ba9dcb
...
...
@@ -46,10 +46,12 @@ def do_parse(
start_page_id
=
0
,
end_page_id
=
None
,
lang
=
None
,
layout_model
=
None
,
formula_enable
=
None
,
table_enable
=
None
,
):
if
debug_able
:
logger
.
warning
(
'debug mode is on'
)
# f_dump_content_list = True
f_draw_model_bbox
=
True
f_draw_line_sort_bbox
=
True
...
...
@@ -64,13 +66,16 @@ def do_parse(
if
parse_method
==
'auto'
:
jso_useful_key
=
{
'_pdf_type'
:
''
,
'model_list'
:
model_list
}
pipe
=
UNIPipe
(
pdf_bytes
,
jso_useful_key
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'txt'
:
pipe
=
TXTPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
elif
parse_method
==
'ocr'
:
pipe
=
OCRPipe
(
pdf_bytes
,
model_list
,
image_writer
,
is_debug
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
)
else
:
logger
.
error
(
'unknown parse method'
)
exit
(
1
)
...
...
magic_pdf/user_api.py
View file @
c1ba9dcb
...
...
@@ -101,11 +101,19 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
if
pdf_info_dict
is
None
or
pdf_info_dict
.
get
(
"_need_drop"
,
False
):
logger
.
warning
(
f
"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr"
)
if
input_model_is_empty
:
pdf_models
=
doc_analyze
(
pdf_bytes
,
ocr
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
)
layout_model
=
kwargs
.
get
(
"layout_model"
,
None
)
formula_enable
=
kwargs
.
get
(
"formula_enable"
,
None
)
table_enable
=
kwargs
.
get
(
"table_enable"
,
None
)
pdf_models
=
doc_analyze
(
pdf_bytes
,
ocr
=
True
,
start_page_id
=
start_page_id
,
end_page_id
=
end_page_id
,
lang
=
lang
,
layout_model
=
layout_model
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
,
)
pdf_info_dict
=
parse_pdf
(
parse_pdf_by_ocr
)
if
pdf_info_dict
is
None
:
raise
Exception
(
"Both parse_pdf_by_txt and parse_pdf_by_ocr failed."
)
...
...
old_docs/download_models.py
View file @
c1ba9dcb
...
...
@@ -5,16 +5,21 @@ import requests
from
modelscope
import
snapshot_download
def
download_json
(
url
):
# 下载JSON文件
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
return
response
.
json
()
def
download_and_modify_json
(
url
,
local_filename
,
modifications
):
if
os
.
path
.
exists
(
local_filename
):
data
=
json
.
load
(
open
(
local_filename
))
config_version
=
data
.
get
(
'config_version'
,
'0.0.0'
)
if
config_version
<
'1.0.0'
:
data
=
download_json
(
url
)
else
:
# 下载JSON文件
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
# 解析JSON内容
data
=
response
.
json
()
data
=
download_json
(
url
)
# 修改内容
for
key
,
value
in
modifications
.
items
():
...
...
@@ -26,13 +31,21 @@ def download_and_modify_json(url, local_filename, modifications):
if
__name__
==
'__main__'
:
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit'
)
mineru_patterns
=
[
"models/Layout/LayoutLMv3/*"
,
"models/Layout/YOLO/*"
,
"models/MFD/YOLO/*"
,
"models/MFR/unimernet_small/*"
,
"models/TabRec/TableMaster/*"
,
"models/TabRec/StructEqTable/*"
,
]
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit-1.0'
,
allow_patterns
=
mineru_patterns
)
layoutreader_model_dir
=
snapshot_download
(
'ppaanngggg/layoutreader'
)
model_dir
=
model_dir
+
'/models'
print
(
f
'model_dir is:
{
model_dir
}
'
)
print
(
f
'layoutreader_model_dir is:
{
layoutreader_model_dir
}
'
)
json_url
=
'https://gitee.com/myhloli/MinerU/raw/
master
/magic-pdf.template.json'
json_url
=
'https://gitee.com/myhloli/MinerU/raw/
dev
/magic-pdf.template.json'
config_file_name
=
'magic-pdf.json'
home_dir
=
os
.
path
.
expanduser
(
'~'
)
config_file
=
os
.
path
.
join
(
home_dir
,
config_file_name
)
...
...
old_docs/download_models_hf.py
View file @
c1ba9dcb
...
...
@@ -5,16 +5,21 @@ import requests
from
huggingface_hub
import
snapshot_download
def
download_json
(
url
):
# 下载JSON文件
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
return
response
.
json
()
def
download_and_modify_json
(
url
,
local_filename
,
modifications
):
if
os
.
path
.
exists
(
local_filename
):
data
=
json
.
load
(
open
(
local_filename
))
config_version
=
data
.
get
(
'config_version'
,
'0.0.0'
)
if
config_version
<
'1.0.0'
:
data
=
download_json
(
url
)
else
:
# 下载JSON文件
response
=
requests
.
get
(
url
)
response
.
raise_for_status
()
# 检查请求是否成功
# 解析JSON内容
data
=
response
.
json
()
data
=
download_json
(
url
)
# 修改内容
for
key
,
value
in
modifications
.
items
():
...
...
@@ -26,13 +31,28 @@ def download_and_modify_json(url, local_filename, modifications):
if
__name__
==
'__main__'
:
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit'
)
layoutreader_model_dir
=
snapshot_download
(
'hantian/layoutreader'
)
mineru_patterns
=
[
"models/Layout/LayoutLMv3/*"
,
"models/Layout/YOLO/*"
,
"models/MFD/YOLO/*"
,
"models/MFR/unimernet_small/*"
,
"models/TabRec/TableMaster/*"
,
"models/TabRec/StructEqTable/*"
,
]
model_dir
=
snapshot_download
(
'opendatalab/PDF-Extract-Kit-1.0'
,
allow_patterns
=
mineru_patterns
)
layoutreader_pattern
=
[
"*.json"
,
"*.safetensors"
,
]
layoutreader_model_dir
=
snapshot_download
(
'hantian/layoutreader'
,
allow_patterns
=
layoutreader_pattern
)
model_dir
=
model_dir
+
'/models'
print
(
f
'model_dir is:
{
model_dir
}
'
)
print
(
f
'layoutreader_model_dir is:
{
layoutreader_model_dir
}
'
)
json_url
=
'https://github.com/opendatalab/MinerU/raw/
master
/magic-pdf.template.json'
json_url
=
'https://github.com/opendatalab/MinerU/raw/
dev
/magic-pdf.template.json'
config_file_name
=
'magic-pdf.json'
home_dir
=
os
.
path
.
expanduser
(
'~'
)
config_file
=
os
.
path
.
join
(
home_dir
,
config_file_name
)
...
...
projects/gradio_app/app.py
View file @
c1ba9dcb
...
...
@@ -23,7 +23,7 @@ def read_fn(path):
return
disk_rw
.
read
(
os
.
path
.
basename
(
path
),
AbsReaderWriter
.
MODE_BIN
)
def
parse_pdf
(
doc_path
,
output_dir
,
end_page_id
,
is_ocr
):
def
parse_pdf
(
doc_path
,
output_dir
,
end_page_id
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
):
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
try
:
...
...
@@ -42,6 +42,10 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr):
parse_method
,
False
,
end_page_id
=
end_page_id
,
layout_model
=
layout_mode
,
formula_enable
=
formula_enable
,
table_enable
=
table_enable
,
lang
=
language
,
)
return
local_md_dir
,
file_name
except
Exception
as
e
:
...
...
@@ -93,9 +97,10 @@ def replace_image_with_base64(markdown_text, image_dir_path):
return
re
.
sub
(
pattern
,
replace
,
markdown_text
)
def
to_markdown
(
file_path
,
end_pages
,
is_ocr
):
def
to_markdown
(
file_path
,
end_pages
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
):
# 获取识别的md文件以及压缩包文件路径
local_md_dir
,
file_name
=
parse_pdf
(
file_path
,
'./output'
,
end_pages
-
1
,
is_ocr
)
local_md_dir
,
file_name
=
parse_pdf
(
file_path
,
'./output'
,
end_pages
-
1
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
)
archive_zip_path
=
os
.
path
.
join
(
"./output"
,
compute_sha256
(
local_md_dir
)
+
".zip"
)
zip_archive_success
=
compress_directory_to_zip
(
local_md_dir
,
archive_zip_path
)
if
zip_archive_success
==
0
:
...
...
@@ -138,6 +143,27 @@ with open("header.html", "r") as file:
header
=
file
.
read
()
latin_lang
=
[
'af'
,
'az'
,
'bs'
,
'cs'
,
'cy'
,
'da'
,
'de'
,
'es'
,
'et'
,
'fr'
,
'ga'
,
'hr'
,
'hu'
,
'id'
,
'is'
,
'it'
,
'ku'
,
'la'
,
'lt'
,
'lv'
,
'mi'
,
'ms'
,
'mt'
,
'nl'
,
'no'
,
'oc'
,
'pi'
,
'pl'
,
'pt'
,
'ro'
,
'rs_latin'
,
'sk'
,
'sl'
,
'sq'
,
'sv'
,
'sw'
,
'tl'
,
'tr'
,
'uz'
,
'vi'
,
'french'
,
'german'
]
arabic_lang
=
[
'ar'
,
'fa'
,
'ug'
,
'ur'
]
cyrillic_lang
=
[
'ru'
,
'rs_cyrillic'
,
'be'
,
'bg'
,
'uk'
,
'mn'
,
'abq'
,
'ady'
,
'kbd'
,
'ava'
,
'dar'
,
'inh'
,
'che'
,
'lbe'
,
'lez'
,
'tab'
]
devanagari_lang
=
[
'hi'
,
'mr'
,
'ne'
,
'bh'
,
'mai'
,
'ang'
,
'bho'
,
'mah'
,
'sck'
,
'new'
,
'gom'
,
'sa'
,
'bgc'
]
other_lang
=
[
'ch'
,
'en'
,
'korean'
,
'japan'
,
'chinese_cht'
,
'ta'
,
'te'
,
'ka'
]
all_lang
=
[
""
]
all_lang
.
extend
([
*
other_lang
,
*
latin_lang
,
*
arabic_lang
,
*
cyrillic_lang
,
*
devanagari_lang
])
if
__name__
==
"__main__"
:
with
gr
.
Blocks
()
as
demo
:
gr
.
HTML
(
header
)
...
...
@@ -145,8 +171,14 @@ if __name__ == "__main__":
with
gr
.
Column
(
variant
=
'panel'
,
scale
=
5
):
pdf_show
=
gr
.
Markdown
()
max_pages
=
gr
.
Slider
(
1
,
10
,
5
,
step
=
1
,
label
=
"Max convert pages"
)
with
gr
.
Row
()
as
bu_flow
:
is_ocr
=
gr
.
Checkbox
(
label
=
"Force enable OCR"
)
with
gr
.
Row
():
layout_mode
=
gr
.
Dropdown
([
"layoutlmv3"
,
"doclayout_yolo"
],
label
=
"Layout model"
,
value
=
"layoutlmv3"
)
language
=
gr
.
Dropdown
(
all_lang
,
label
=
"Language"
,
value
=
""
)
with
gr
.
Row
():
formula_enable
=
gr
.
Checkbox
(
label
=
"Enable formula recognition"
,
value
=
True
)
is_ocr
=
gr
.
Checkbox
(
label
=
"Force enable OCR"
,
value
=
False
)
table_enable
=
gr
.
Checkbox
(
label
=
"Enable table recognition(test)"
,
value
=
False
)
with
gr
.
Row
():
change_bu
=
gr
.
Button
(
"Convert"
)
clear_bu
=
gr
.
ClearButton
([
pdf_show
],
value
=
"Clear"
)
pdf_show
=
PDF
(
label
=
"Please upload pdf"
,
interactive
=
True
,
height
=
800
)
...
...
@@ -166,7 +198,8 @@ if __name__ == "__main__":
latex_delimiters
=
latex_delimiters
,
line_breaks
=
True
)
with
gr
.
Tab
(
"Markdown text"
):
md_text
=
gr
.
TextArea
(
lines
=
45
,
show_copy_button
=
True
)
change_bu
.
click
(
fn
=
to_markdown
,
inputs
=
[
pdf_show
,
max_pages
,
is_ocr
],
outputs
=
[
md
,
md_text
,
output_file
,
pdf_show
])
change_bu
.
click
(
fn
=
to_markdown
,
inputs
=
[
pdf_show
,
max_pages
,
is_ocr
,
layout_mode
,
formula_enable
,
table_enable
,
language
],
outputs
=
[
md
,
md_text
,
output_file
,
pdf_show
])
clear_bu
.
add
([
md
,
pdf_show
,
md_text
,
output_file
,
is_ocr
])
demo
.
launch
()
\ No newline at end of file
demo
.
launch
(
server_name
=
"0.0.0.0"
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment