Commit 7bb8f0e9 authored by myhloli

refactor: streamline model path handling and enhance file retrieval logic

parent 0039d113
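
The changes below drop the config-file lookups (get_local_models_dir, get_local_layoutreader_model_dir) in favour of repository-relative constants on ModelPath that are resolved, and downloaded when missing, by the new get_file_from_repos helper. A minimal sketch of the new call pattern, using only names that appear in the diff (the snippet itself is illustrative, not part of the commit):

    from mineru.utils.enum_class import ModelPath
    from mineru.utils.models_download_utils import get_file_from_repos

    # Single weight file: returns an absolute local path, fetching it if it is not cached yet
    mfd_weights = get_file_from_repos(ModelPath.yolo_v8_mfd)

    # Model directory: returns the local directory that mirrors the repo-relative path
    layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader)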
@@ -67,28 +67,6 @@ def parse_bucket_key(s3_full_path: str):
    return bucket, key


-def get_local_models_dir():
-    config = read_config()
-    models_dir = config.get('models-dir')
-    if models_dir is None:
-        logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
-        return '/tmp/models'
-    else:
-        return models_dir
-
-
-def get_local_layoutreader_model_dir():
-    config = read_config()
-    layoutreader_model_dir = config.get('layoutreader-model-dir')
-    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
-        home_dir = os.path.expanduser('~')
-        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
-        logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
-        return layoutreader_at_modelscope_dir_path
-    else:
-        return layoutreader_model_dir
-
-
def get_device():
    device_mode = os.getenv('MINERU_DEVICE_MODE', None)
    if device_mode is not None:
...
@@ -9,10 +9,8 @@ from ...model.mfd.yolo_v8 import YOLOv8MFDModel
from ...model.mfr.unimernet.Unimernet import UnimernetModel
from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
from ...model.table.rapid_table import RapidTableModel
+from ...utils.enum_class import ModelPath
+from ...utils.models_download_utils import get_file_from_repos

-doclayout_yolo = "Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
-yolo_v8_mfd = "MFD/YOLO/yolo_v8_ft.pt"
-unimernet_small = "MFR/unimernet_hf_small_2503"


def table_model_init(lang=None):
@@ -150,14 +148,14 @@ class MineruPipelineModel:
        self.mfd_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.MFD,
            mfd_weights=str(
-                os.path.join(models_dir, yolo_v8_mfd)
+                os.path.join(models_dir, get_file_from_repos(ModelPath.yolo_v8_mfd))
            ),
            device=self.device,
        )
        # Initialize the formula recognition model
        mfr_weight_dir = str(
-            os.path.join(models_dir, unimernet_small)
+            os.path.join(models_dir, get_file_from_repos(ModelPath.unimernet_small))
        )
        self.mfr_model = atom_model_manager.get_atom_model(
@@ -170,7 +168,7 @@ class MineruPipelineModel:
        self.layout_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.Layout,
            doclayout_yolo_weights=str(
-                os.path.join(models_dir, doclayout_yolo)
+                os.path.join(models_dir, get_file_from_repos(ModelPath.doclayout_yolo))
            ),
            device=self.device,
        )
...
@@ -9,7 +9,9 @@ import numpy as np
import yaml
from loguru import logger

-from mineru.backend.pipeline.config_reader import get_device, get_local_models_dir
+from mineru.backend.pipeline.config_reader import get_device
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import get_file_from_repos
from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
from .tools.infer.predict_system import TextSystem
from .tools.infer import pytorchocr_utility as utility
@@ -74,9 +76,11 @@ class PytorchPaddleOCR(TextSystem):
        with open(models_config_path) as file:
            config = yaml.safe_load(file)
        det, rec, dict_file = get_model_params(self.lang, config)
-        ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
-        kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
-        kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
+        ocr_models_dir = ModelPath.pytorch_paddle
+        det_model_path = get_file_from_repos(f"{ocr_models_dir}/{det}")
+        rec_model_path = get_file_from_repos(f"{ocr_models_dir}/{rec}")
+        kwargs['det_model_path'] = det_model_path
+        kwargs['rec_model_path'] = rec_model_path
        kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
        # kwargs['rec_batch_num'] = 8
...
@@ -7,8 +7,9 @@ from typing import List
import torch
from loguru import logger

-from mineru.backend.pipeline.config_reader import get_device, get_local_layoutreader_model_dir
-from mineru.utils.enum_class import BlockType
+from mineru.backend.pipeline.config_reader import get_device
+from mineru.utils.enum_class import BlockType, ModelPath
+from mineru.utils.models_download_utils import get_file_from_repos


def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
@@ -187,7 +188,7 @@ def model_init(model_name: str):
    device = torch.device(device_name)
    if model_name == 'layoutreader':
        # Check whether the ModelScope cache directory exists
-        layoutreader_model_dir = get_local_layoutreader_model_dir()
+        layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader)
        if os.path.exists(layoutreader_model_dir):
            model = LayoutLMv3ForTokenClassification.from_pretrained(
                layoutreader_model_dir
...
@@ -43,3 +43,15 @@ class MakeMode:
    MM_MD = 'mm_markdown'
    NLP_MD = 'nlp_markdown'
    STANDARD_FORMAT = 'standard_format'
+
+
+class ModelPath:
+    pipeline_root_modelscope = "OpenDataLab/PDF-Extract-Kit-1.0"
+    pipeline_root_hf = "opendatalab/PDF-Extract-Kit-1.0"
+    doclayout_yolo = "models/Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
+    yolo_v8_mfd = "models/MFD/YOLO/yolo_v8_ft.pt"
+    unimernet_small = "models/MFR/unimernet_hf_small_2503"
+    pytorch_paddle = "models/OCR/paddleocr_torch"
+    layout_reader = "models/ReadingOrder/layout_reader"
+    vlm_root_hf = "opendatalab/MinerU-VLM-1.0"
+    vlm_root_modelscope = "OpenDataLab/MinerU-VLM-1.0"
\ No newline at end of file
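
The new downloader module (imported in the hunks above as mineru.utils.models_download_utils and listed next) pairs one of the *_root_* constants with a repo-relative constant such as ModelPath.doclayout_yolo, choosing the root from repo_mode and the MINERU_MODEL_SOURCE environment variable. A rough sketch of the resulting Huggingface lookup, for illustration only and not part of the commit:

    from huggingface_hub import hf_hub_download
    from mineru.utils.enum_class import ModelPath

    # repo_id comes from a *_root_* constant, filename from a repo-relative constant
    weights = hf_hub_download(repo_id=ModelPath.pipeline_root_hf, filename=ModelPath.doclayout_yolo)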
import os
import hashlib
import requests
from typing import List, Union

from huggingface_hub import hf_hub_download, model_info
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE

from mineru.utils.enum_class import ModelPath


def _sha256sum(path, chunk_size=8192):
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            h.update(chunk)
    return h.hexdigest()


def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> Union[str, str]:
    """
    Reliable download of a file or a directory.
    - If the input is a file: return the absolute path of the local file.
    - If the input is a directory: return the local cache directory that mirrors relative_path.
    :param repo_mode: repository mode, either 'pipeline' or 'vlm'
    :param relative_path: relative path of the file or directory
    :return: absolute local file path, or local directory path
    """
    model_source = os.getenv('MINERU_MODEL_SOURCE', None)

    # Map each repo mode to its repository roots
    repo_mapping = {
        'pipeline': {
            'huggingface': ModelPath.pipeline_root_hf,
            'modelscope': ModelPath.pipeline_root_modelscope,
            'default': ModelPath.pipeline_root_hf
        },
        'vlm': {
            'huggingface': ModelPath.vlm_root_hf,
            'modelscope': ModelPath.vlm_root_modelscope,
            'default': ModelPath.vlm_root_hf
        }
    }
    if repo_mode not in repo_mapping:
        raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")

    # Fall back to the default root when model_source is unset or not 'modelscope'
    repo = repo_mapping[repo_mode].get(model_source, repo_mapping[repo_mode]['default'])

    input_clean = relative_path.strip('/')

    # Fetch the Huggingface repository file tree
    try:
        # Repository info, including per-file metadata
        info = model_info(repo, files_metadata=True)
        # Build a dict of repository files
        siblings_dict = {f.rfilename: f for f in info.siblings}
    except Exception as e:
        siblings_dict = {}
        print(f"[Warn] Failed to fetch the Huggingface repository structure: {e}")

    # 1. Decide whether the input is a file or a directory and expand it
    if input_clean in siblings_dict and not siblings_dict[input_clean].rfilename.endswith("/"):
        is_file = True
        all_paths = [input_clean]
    else:
        is_file = False
        all_paths = [k for k in siblings_dict if k.startswith(input_clean + "/") and not k.endswith("/")]
    # If no siblings are available (e.g. the Huggingface call failed), fall back to the raw input
    if not all_paths:
        is_file = os.path.splitext(input_clean)[1] != ""
        all_paths = [input_clean] if is_file else []

    cache_home = str(HUGGINGFACE_HUB_CACHE)

    # Main resolution logic
    output_files = []

    # ---- Huggingface branch ----
    hf_ok = False
    for relpath in all_paths:
        ok = False
        if relpath in siblings_dict:
            meta = siblings_dict[relpath]
            sha256 = ""
            if meta.lfs:
                sha256 = meta.lfs.sha256
            try:
                # Do not download from the network here; only look for a local copy
                file_path = hf_hub_download(repo_id=repo, filename=relpath, local_files_only=True)
                if sha256 and os.path.exists(file_path):
                    if _sha256sum(file_path) == sha256:
                        ok = True
                        output_files.append(file_path)
            except Exception as e:
                print(f"[Info] Huggingface fetch failed for {relpath}: {e}")
            if not hf_ok:
                file_path = hf_hub_download(repo_id=repo, filename=relpath, force_download=False)
                print("file_path = ", file_path)
                if sha256 and _sha256sum(file_path) != sha256:
                    raise ValueError(f"Checksum verification failed after Huggingface download: {relpath}")
                ok = True
                output_files.append(file_path)
        hf_ok = hf_ok and ok

    # ---- ModelScope branch ----
    for relpath in all_paths:
        if hf_ok:
            break
        if "/" in repo:
            org_name, model_name = repo.split("/", 1)
        else:
            org_name, model_name = "modelscope", repo
        # Directory layout: <cache home>/modelscope-fallback/<org>/<model>/<relative path>
        target_dir = os.path.join(cache_home, "modelscope-fallback", org_name, model_name, os.path.dirname(relpath))
        os.makedirs(target_dir, exist_ok=True)
        local_path = os.path.join(target_dir, os.path.basename(relpath))

        remote_len = 0
        sha256 = ""
        try:
            get_meta_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo/raw?Revision=master&FilePath={relpath}&Needmeta=true"
            resp = requests.get(get_meta_url, timeout=15)
            if resp.ok:
                remote_len = resp.json()["Data"]["MetaContent"]["Size"]
                sha256 = resp.json()["Data"]["MetaContent"]["Sha256"]
        except Exception as e:
            print(f"[Info] modelscope metadata fetch failed for {relpath}: {e}")

        ok_local = False
        if remote_len > 0 and os.path.exists(local_path):
            if sha256 == _sha256sum(local_path):
                output_files.append(local_path)
                ok_local = True
        if not ok_local:
            try:
                modelscope_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo?Revision=master&FilePath={relpath}"
                with requests.get(modelscope_url, stream=True, timeout=30) as resp:
                    resp.raise_for_status()
                    with open(local_path, 'wb') as f:
                        for chunk in resp.iter_content(1024*1024):
                            if chunk:
                                f.write(chunk)
                if remote_len == 0 or os.path.getsize(local_path) == remote_len:
                    output_files.append(local_path)
                    ok_local = True
            except Exception as e:
                print(f"[Error] ModelScope download failed: {relpath} {e}")

    if not output_files:
        raise FileNotFoundError(f"{relative_path} could not be retrieved from either Huggingface or ModelScope")

    if is_file:
        return output_files[0]
    else:
        # The input was a directory: return the directory that contains the resolved files
        return os.path.dirname(os.path.abspath(output_files[0]))


if __name__ == '__main__':
    path1 = get_file_from_repos("models/README.md")
    print("Local file absolute path:", path1)
    path2 = get_file_from_repos("models/OCR/paddleocr_torch/")
    print("Local directory path:", path2)
\ No newline at end of file
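
Since get_file_from_repos reads MINERU_MODEL_SOURCE at call time, the source repository can be switched per process; a minimal usage sketch (the variable name and accepted value are taken from the code above, the snippet itself is illustrative):

    import os
    os.environ["MINERU_MODEL_SOURCE"] = "modelscope"  # resolve against the ModelScope repository roots
    weights = get_file_from_repos(ModelPath.doclayout_yolo)  # any other value falls back to the Huggingface root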