Commit cacf79fa authored by myhloli's avatar myhloli
Browse files

refactor: add models_download.py for downloading and configuring MinerU models

parent a8747f1d
import json
import os
import sys
import click
import requests
from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
def download_json(url):
"""下载JSON文件"""
response = requests.get(url)
response.raise_for_status()
return response.json()
def download_and_modify_json(url, local_filename, modifications):
"""下载JSON并修改内容"""
if os.path.exists(local_filename):
data = json.load(open(local_filename))
config_version = data.get('config_version', '0.0.0')
if config_version < '1.3.0':
data = download_json(url)
else:
data = download_json(url)
# 修改内容
for key, value in modifications.items():
if key in data:
if isinstance(data[key], dict):
# 如果是字典,合并新值
data[key].update(value)
else:
# 否则直接替换
data[key] = value
# 保存修改后的内容
with open(local_filename, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=4)
def configure_model(model_dir, model_type):
"""配置模型"""
json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/mineru.template.json'
config_file_name = 'mineru.json'
home_dir = os.path.expanduser('~')
config_file = os.path.join(home_dir, config_file_name)
json_mods = {
'models-dir': {
f'{model_type}': model_dir
}
}
download_and_modify_json(json_url, config_file, json_mods)
print(f'The configuration file has been successfully configured, the path is: {config_file}')
@click.command()
def download_models():
"""下载MinerU模型文件。
支持从ModelScope或HuggingFace下载pipeline或VLM模型。
"""
# 交互式输入下载来源
source = click.prompt(
"Please select the model download source: ",
type=click.Choice(['huggingface', 'modelscope']),
default='huggingface'
)
os.environ['MINERU_MODEL_SOURCE'] = source
# 交互式输入模型类型
model_type = click.prompt(
"Please select the model type to download: ",
type=click.Choice(['pipeline', 'vlm']),
default='pipeline'
)
click.echo(f"Downloading {model_type} model from {source}...")
try:
download_finish_path = ""
if model_type == 'pipeline':
for model_path in [ModelPath.doclayout_yolo, ModelPath.yolo_v8_mfd, ModelPath.unimernet_small, ModelPath.pytorch_paddle, ModelPath.layout_reader, ModelPath.slanet_plus]:
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode=model_type)
elif model_type == 'vlm':
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode=model_type)
click.echo(f"Models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, model_type)
except Exception as e:
click.echo(f"Download failed: {str(e)}", err=True)
sys.exit(1)
if __name__ == '__main__':
download_models()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment