Commit c7c1e30e authored by myhloli's avatar myhloli
Browse files

feat: implement model downloading functions with logging

parent 43f21a77
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import sys import sys
import click import click
import requests import requests
from loguru import logger
from mineru.utils.enum_class import ModelPath from mineru.utils.enum_class import ModelPath
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
...@@ -57,6 +58,31 @@ def configure_model(model_dir, model_type): ...@@ -57,6 +58,31 @@ def configure_model(model_dir, model_type):
print(f'The configuration file has been successfully configured, the path is: {config_file}') print(f'The configuration file has been successfully configured, the path is: {config_file}')
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus
]
download_finish_path = ""
for model_path in model_paths:
click.echo(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "pipeline")
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "vlm")
@click.command() @click.command()
@click.option( @click.option(
'-s', '-s',
...@@ -102,30 +128,7 @@ def download_models(model_source, model_type): ...@@ -102,30 +128,7 @@ def download_models(model_source, model_type):
default='all' default='all'
) )
click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...") logger.info(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus
]
download_finish_path = ""
for model_path in model_paths:
click.echo(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "pipeline")
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, "vlm")
try: try:
if model_type == 'pipeline': if model_type == 'pipeline':
...@@ -140,7 +143,7 @@ def download_models(model_source, model_type): ...@@ -140,7 +143,7 @@ def download_models(model_source, model_type):
sys.exit(1) sys.exit(1)
except Exception as e: except Exception as e:
click.echo(f"Download failed: {str(e)}", err=True) logger.exception(f"An error occurred while downloading models: {str(e)}")
sys.exit(1) sys.exit(1)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment