Commit 705f8331 authored by myhloli's avatar myhloli
Browse files

feat: enhance model download options with source and type parameters

parent d55db09e
......@@ -24,7 +24,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
)
@click.option(
'-o',
'--output-dir',
'--output',
'output_dir',
type=click.Path(),
required=True,
......@@ -118,16 +118,14 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
default=None,
)
@click.option(
'-vm',
'--virtual-vram',
'--vram',
'virtual_vram',
type=int,
help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
default=None,
)
@click.option(
'-r',
'--repo',
'--source',
'model_source',
type=click.Choice(['huggingface', 'modelscope', 'local']),
help="""
......
@click.command()
@click.option(
    '-s',
    '--source',
    'model_source',
    type=click.Choice(['huggingface', 'modelscope']),
    help="""
        The source of the model repository.
        """,
    default=None,
)
@click.option(
    '-m',
    '--model_type',
    'model_type',
    type=click.Choice(['pipeline', 'vlm', 'all']),
    help="""
        The type of the model to download.
        """,
    default=None,
)
def download_models(model_source, model_type):
    """Download MinerU model files.
    Supports downloading pipeline or VLM models from ModelScope or HuggingFace.

    Both options default to None; when omitted, the user is prompted
    interactively for the missing value. Exits with status 1 on any
    download failure or on an unsupported model type.
    """
    # Interactively ask for the download source when not given explicitly.
    if model_source is None:
        model_source = click.prompt(
            "Please select the model download source: ",
            type=click.Choice(['huggingface', 'modelscope']),
            default='huggingface'
        )
    # A pre-set MINERU_MODEL_SOURCE environment variable takes precedence;
    # only fall back to the selected source when it is unset.
    if os.getenv('MINERU_MODEL_SOURCE', None) is None:
        os.environ['MINERU_MODEL_SOURCE'] = model_source
    # Interactively ask for the model type when not given explicitly.
    if model_type is None:
        model_type = click.prompt(
            "Please select the model type to download: ",
            type=click.Choice(['pipeline', 'vlm', 'all']),
            default='all'
        )
    click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")

    def download_pipeline_models():
        """Download all pipeline models, then record their root path in the config."""
        model_paths = [
            ModelPath.doclayout_yolo,
            ModelPath.yolo_v8_mfd,
            ModelPath.unimernet_small,
            ModelPath.pytorch_paddle,
            ModelPath.layout_reader,
            ModelPath.slanet_plus
        ]
        download_finish_path = ""
        for model_path in model_paths:
            click.echo(f"Downloading model: {model_path}")
            # All pipeline models share one root path; the last call's result is kept.
            download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
        click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
        configure_model(download_finish_path, model_type)

    def download_vlm_models():
        """Download the VLM model, then record its root path in the config."""
        download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
        click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
        configure_model(download_finish_path, model_type)

    try:
        if model_type == 'pipeline':
            download_pipeline_models()
        elif model_type == 'vlm':
            download_vlm_models()
        elif model_type == 'all':
            download_pipeline_models()
            download_vlm_models()
        else:
            # Defensive: click.Choice should already reject other values.
            click.echo(f"Unsupported model type: {model_type}", err=True)
            sys.exit(1)
    except Exception as e:
        # Surface any download/config error on stderr and exit non-zero.
        click.echo(f"Download failed: {str(e)}", err=True)
        sys.exit(1)


if __name__ == '__main__':
    download_models()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment