wangsen / MinerU · Commits · 7bb8f0e9

Commit 7bb8f0e9, authored Jun 06, 2025 by myhloli

refactor: streamline model path handling and enhance file retrieval logic

parent 0039d113
Showing 6 changed files with 183 additions and 37 deletions (+183, -37).
  mineru/backend/pipeline/config_reader.py               +0    -22
  mineru/backend/pipeline/model_init.py                   +5    -7
  mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py    +8    -4
  mineru/utils/block_sort.py                              +4    -3
  mineru/utils/enum_class.py                              +13   -1
  mineru/utils/models_download_utils.py (new file)        +153  -0
mineru/backend/pipeline/config_reader.py
```diff
@@ -67,28 +67,6 @@ def parse_bucket_key(s3_full_path: str):
     return bucket, key


-def get_local_models_dir():
-    config = read_config()
-    models_dir = config.get('models-dir')
-    if models_dir is None:
-        logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
-        return '/tmp/models'
-    else:
-        return models_dir
-
-
-def get_local_layoutreader_model_dir():
-    config = read_config()
-    layoutreader_model_dir = config.get('layoutreader-model-dir')
-    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
-        home_dir = os.path.expanduser('~')
-        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
-        logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
-        return layoutreader_at_modelscope_dir_path
-    else:
-        return layoutreader_model_dir
-
-
 def get_device():
     device_mode = os.getenv('MINERU_DEVICE_MODE', None)
     if device_mode is not None:
```
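Note: the two removed helpers resolved model locations from the local config file; after this commit that responsibility moves to `get_file_from_repos` in the new `mineru/utils/models_download_utils.py`. Below is a minimal sketch of the migration for a hypothetical caller (variable names are illustrative; the actual call sites in this commit additionally keep an `os.path.join(models_dir, ...)` wrapper around the helper):

```python
# Before (config-file lookup, removed in this commit):
#   models_dir = get_local_models_dir()        # read 'models-dir' from the config file
#   weights = os.path.join(models_dir, "MFD/YOLO/yolo_v8_ft.pt")

# After (hub-backed resolution via the new helper):
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos

weights = get_file_from_repos(ModelPath.yolo_v8_mfd)  # absolute path to a locally cached file
```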
mineru/backend/pipeline/model_init.py
```diff
@@ -9,10 +9,8 @@ from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.table.rapid_table import RapidTableModel
-
-doclayout_yolo = "Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
-yolo_v8_mfd = "MFD/YOLO/yolo_v8_ft.pt"
-unimernet_small = "MFR/unimernet_hf_small_2503"
+from ...utils.enum_class import ModelPath
+from ...utils.models_download_utils import get_file_from_repos


 def table_model_init(lang=None):
@@ -150,14 +148,14 @@ class MineruPipelineModel:
         self.mfd_model = atom_model_manager.get_atom_model(
             atom_model_name=AtomicModel.MFD,
-            mfd_weights=str(os.path.join(models_dir, yolo_v8_mfd)),
+            mfd_weights=str(os.path.join(models_dir, get_file_from_repos(ModelPath.yolo_v8_mfd))),
             device=self.device,
         )
         # Initialize the formula recognition model
-        mfr_weight_dir = str(os.path.join(models_dir, unimernet_small))
+        mfr_weight_dir = str(os.path.join(models_dir, get_file_from_repos(ModelPath.unimernet_small)))
         self.mfr_model = atom_model_manager.get_atom_model(
@@ -170,7 +168,7 @@ class MineruPipelineModel:
         self.layout_model = atom_model_manager.get_atom_model(
             atom_model_name=AtomicModel.Layout,
-            doclayout_yolo_weights=str(os.path.join(models_dir, doclayout_yolo)),
+            doclayout_yolo_weights=str(os.path.join(models_dir, get_file_from_repos(ModelPath.doclayout_yolo))),
             device=self.device,
         )
```
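Note: since `get_file_from_repos` returns an absolute path, wrapping it in `os.path.join(models_dir, ...)` effectively ignores `models_dir` on POSIX systems, because `os.path.join` discards earlier components when a later one is absolute. A tiny sketch of that stdlib behavior (the cache path below is made up for illustration):

```python
import os

models_dir = "/tmp/models"
resolved = "/home/user/.cache/huggingface/hub/models--opendatalab--PDF-Extract-Kit-1.0/snapshots/abc/models/MFD/YOLO/yolo_v8_ft.pt"

# Joining onto an absolute second argument returns the absolute path unchanged,
# so the configured models_dir no longer influences the final weights path.
print(os.path.join(models_dir, resolved) == resolved)  # True on POSIX
```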
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py
```diff
@@ -9,7 +9,9 @@ import numpy as np
 import yaml
 from loguru import logger

-from mineru.backend.pipeline.config_reader import get_device, get_local_models_dir
+from mineru.backend.pipeline.config_reader import get_device
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import get_file_from_repos
 from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
 from .tools.infer.predict_system import TextSystem
 from .tools.infer import pytorchocr_utility as utility
@@ -74,9 +76,11 @@ class PytorchPaddleOCR(TextSystem):
         with open(models_config_path) as file:
             config = yaml.safe_load(file)
         det, rec, dict_file = get_model_params(self.lang, config)
-        ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
-        kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
-        kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
+        ocr_models_dir = ModelPath.pytorch_paddle
+        det_model_path = get_file_from_repos(f"{ocr_models_dir}/{det}")
+        rec_model_path = get_file_from_repos(f"{ocr_models_dir}/{rec}")
+        kwargs['det_model_path'] = det_model_path
+        kwargs['rec_model_path'] = rec_model_path
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
         # kwargs['rec_batch_num'] = 8
```
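Note: the OCR detection and recognition weights are now requested as repo-relative paths built from `ModelPath.pytorch_paddle` instead of a configured local directory. A hedged sketch of the resulting call shape (the `det` file name below is a placeholder, not taken from this diff):

```python
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos

ocr_models_dir = ModelPath.pytorch_paddle             # "models/OCR/paddleocr_torch"
det = "example_det_infer.pth"                         # placeholder weight file name
det_model_path = get_file_from_repos(f"{ocr_models_dir}/{det}")
# det_model_path is an absolute path to a locally cached (or freshly fetched) file
```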
mineru/utils/block_sort.py
```diff
@@ -7,8 +7,9 @@ from typing import List
 import torch
 from loguru import logger

-from mineru.backend.pipeline.config_reader import get_device, get_local_layoutreader_model_dir
-from mineru.utils.enum_class import BlockType
+from mineru.backend.pipeline.config_reader import get_device
+from mineru.utils.enum_class import BlockType, ModelPath
+from mineru.utils.models_download_utils import get_file_from_repos


 def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
@@ -187,7 +188,7 @@ def model_init(model_name: str):
         device = torch.device(device_name)
     if model_name == 'layoutreader':
         # Check whether the modelscope cache directory exists
-        layoutreader_model_dir = get_local_layoutreader_model_dir()
+        layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader)
         if os.path.exists(layoutreader_model_dir):
             model = LayoutLMv3ForTokenClassification.from_pretrained(
                 layoutreader_model_dir
```
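Note: `ModelPath.layout_reader` ("models/ReadingOrder/layout_reader") names a directory, and for directory inputs the new helper returns the local directory containing the resolved files, which is then passed straight to `from_pretrained`. A minimal sketch of that usage, assuming the layoutreader weights are available in the configured hub repo:

```python
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos

# For a directory input, get_file_from_repos resolves the files under the prefix
# and returns the directory holding the first resolved file.
layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader)
print(layoutreader_model_dir)
```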
mineru/utils/enum_class.py
```diff
@@ -43,3 +43,15 @@ class MakeMode:
     MM_MD = 'mm_markdown'
     NLP_MD = 'nlp_markdown'
-    STANDARD_FORMAT = 'standard_format'
+    STANDARD_FORMAT = 'standard_format'
+
+
+class ModelPath:
+    pipeline_root_modelscope = "OpenDataLab/PDF-Extract-Kit-1.0"
+    pipeline_root_hf = "opendatalab/PDF-Extract-Kit-1.0"
+    doclayout_yolo = "models/Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
+    yolo_v8_mfd = "models/MFD/YOLO/yolo_v8_ft.pt"
+    unimernet_small = "models/MFR/unimernet_hf_small_2503"
+    pytorch_paddle = "models/OCR/paddleocr_torch"
+    layout_reader = "models/ReadingOrder/layout_reader"
+    vlm_root_hf = "opendatalab/MinerU-VLM-1.0"
+    vlm_root_modelscope = "OpenDataLab/MinerU-VLM-1.0"
\ No newline at end of file
```
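Note: `ModelPath` centralizes both the hub repository roots and the repo-relative paths; the relative constants are paired with one of the roots inside `get_file_from_repos`. As a rough sketch, the single-file Hugging Face case reduces to an `hf_hub_download` call like the one below (the pairing is illustrative of what the helper does, not a verbatim excerpt):

```python
from huggingface_hub import hf_hub_download
from mineru.utils.enum_class import ModelPath

# Root constants name the hub repositories; the others are paths inside them.
repo_id = ModelPath.pipeline_root_hf   # "opendatalab/PDF-Extract-Kit-1.0"
filename = ModelPath.doclayout_yolo    # "models/Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"

# Roughly what get_file_from_repos does for a single file on the Hugging Face path.
local_path = hf_hub_download(repo_id=repo_id, filename=filename)
print(local_path)
```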
mineru/utils/models_download_utils.py
new file, mode 100644

```python
import os
import hashlib
import requests
from typing import List, Union

from huggingface_hub import hf_hub_download, model_info
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from mineru.utils.enum_class import ModelPath


def _sha256sum(path, chunk_size=8192):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()


def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> Union[str, str]:
    """
    Reliable retrieval of a file or a directory.
    - File input: returns the absolute path of the local file.
    - Directory input: returns the local cache path mirroring relative_path.
    :param repo_mode: repository mode, 'pipeline' or 'vlm'
    :param relative_path: relative path of the file or directory
    :return: absolute local file path, or local directory path
    """
    model_source = os.getenv('MINERU_MODEL_SOURCE', None)

    # Map repo mode to the repository roots
    repo_mapping = {
        'pipeline': {
            'huggingface': ModelPath.pipeline_root_hf,
            'modelscope': ModelPath.pipeline_root_modelscope,
            'default': ModelPath.pipeline_root_hf,
        },
        'vlm': {
            'huggingface': ModelPath.vlm_root_hf,
            'modelscope': ModelPath.vlm_root_modelscope,
            'default': ModelPath.vlm_root_hf,
        },
    }
    if repo_mode not in repo_mapping:
        raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")

    # Fall back to the default root when model_source is unset or not 'modelscope'
    repo = repo_mapping[repo_mode].get(model_source, repo_mapping[repo_mode]['default'])

    input_clean = relative_path.strip('/')

    # Fetch the remote Hugging Face repository file tree
    try:
        # Repository info, including file metadata
        info = model_info(repo, files_metadata=True)
        # Build a filename -> metadata dict
        siblings_dict = {f.rfilename: f for f in info.siblings}
    except Exception as e:
        siblings_dict = {}
        print(f"[Warn] 获取 Huggingface 仓库结构失败,错误: {e}")

    # 1. Expand the input into a single file or a directory listing
    if input_clean in siblings_dict and not siblings_dict[input_clean].rfilename.endswith("/"):
        is_file = True
        all_paths = [input_clean]
    else:
        is_file = False
        all_paths = [k for k in siblings_dict if k.startswith(input_clean + "/") and not k.endswith("/")]

    # If no siblings are available (e.g. the Hugging Face lookup failed), treat the input as-is
    if not all_paths:
        is_file = os.path.splitext(input_clean)[1] != ""
        all_paths = [input_clean] if is_file else []

    cache_home = str(HUGGINGFACE_HUB_CACHE)

    # Main resolution logic
    output_files = []

    # ---- Hugging Face branch ----
    hf_ok = False
    for relpath in all_paths:
        ok = False
        if relpath in siblings_dict:
            meta = siblings_dict[relpath]
            sha256 = ""
            if meta.lfs:
                sha256 = meta.lfs.sha256
            try:
                # Do not download; only look for an already cached local file
                file_path = hf_hub_download(repo_id=repo, filename=relpath, local_files_only=True)
                if sha256 and os.path.exists(file_path):
                    if _sha256sum(file_path) == sha256:
                        ok = True
                        output_files.append(file_path)
            except Exception as e:
                print(f"[Info] Huggingface {relpath} 获取失败: {e}")
            if not hf_ok:
                file_path = hf_hub_download(repo_id=repo, filename=relpath, force_download=False)
                print("file_path = ", file_path)
                if sha256 and _sha256sum(file_path) != sha256:
                    raise ValueError(f"Huggingface下载后校验失败: {relpath}")
                ok = True
                output_files.append(file_path)
        hf_ok = hf_ok and ok

    # ---- ModelScope branch ----
    for relpath in all_paths:
        if hf_ok:
            break
        if "/" in repo:
            org_name, model_name = repo.split("/", 1)
        else:
            org_name, model_name = "modelscope", repo
        # Directory layout: <cache_home>/modelscope-fallback/<org>/<model>/<relative path>
        target_dir = os.path.join(cache_home, "modelscope-fallback", org_name, model_name, os.path.dirname(relpath))
        os.makedirs(target_dir, exist_ok=True)
        local_path = os.path.join(target_dir, os.path.basename(relpath))

        remote_len = 0
        sha256 = ""
        try:
            get_meta_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo/raw?Revision=master&FilePath={relpath}&Needmeta=true"
            resp = requests.get(get_meta_url, timeout=15)
            if resp.ok:
                remote_len = resp.json()["Data"]["MetaContent"]["Size"]
                sha256 = resp.json()["Data"]["MetaContent"]["Sha256"]
        except Exception as e:
            print(f"[Info] modelscope {relpath} 获取失败: {e}")

        ok_local = False
        if remote_len > 0 and os.path.exists(local_path):
            if sha256 == _sha256sum(local_path):
                output_files.append(local_path)
                ok_local = True
        if not ok_local:
            try:
                modelscope_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo?Revision=master&FilePath={relpath}"
                with requests.get(modelscope_url, stream=True, timeout=30) as resp:
                    resp.raise_for_status()
                    with open(local_path, 'wb') as f:
                        for chunk in resp.iter_content(1024 * 1024):
                            if chunk:
                                f.write(chunk)
                if remote_len == 0 or os.path.getsize(local_path) == remote_len:
                    output_files.append(local_path)
                    ok_local = True
            except Exception as e:
                print(f"[Error] ModelScope下载失败: {relpath} {e}")

    if not output_files:
        raise FileNotFoundError(f"{relative_path} 在 Huggingface 和 ModelScope 都未能获取")

    if is_file:
        return output_files[0]
    else:
        # Directory input: return the containing directory as a path string
        return os.path.dirname(os.path.abspath(output_files[0]))


if __name__ == '__main__':
    path1 = get_file_from_repos("models/README.md")
    print("本地文件绝对路径:", path1)
    path2 = get_file_from_repos("models/OCR/paddleocr_torch/")
    print("本地文件绝对路径:", path2)
```
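Note: the helper selects the repository root via the `MINERU_MODEL_SOURCE` environment variable, falling back to the Hugging Face root when the variable is unset or not 'modelscope'. A hedged usage sketch based on the code above:

```python
import os

# Prefer the ModelScope repo root; any other value (or no value) falls back
# to the Hugging Face root defined in ModelPath.
os.environ['MINERU_MODEL_SOURCE'] = 'modelscope'

from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import get_file_from_repos

weights = get_file_from_repos(ModelPath.yolo_v8_mfd)      # single file -> absolute file path
ocr_dir = get_file_from_repos(ModelPath.pytorch_paddle)   # directory   -> containing directory path
```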