Commit 705f8331 authored by myhloli
Browse files

feat: enhance model download options with source and type parameters

parent d55db09e
...@@ -24,7 +24,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes ...@@ -24,7 +24,7 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
) )
@click.option( @click.option(
'-o', '-o',
'--output-dir', '--output',
'output_dir', 'output_dir',
type=click.Path(), type=click.Path(),
required=True, required=True,
...@@ -118,16 +118,14 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes ...@@ -118,16 +118,14 @@ from .common import do_parse, read_fn, pdf_suffixes, image_suffixes
default=None, default=None,
) )
@click.option( @click.option(
'-vm', '--vram',
'--virtual-vram',
'virtual_vram', 'virtual_vram',
type=int, type=int,
help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ', help='Upper limit of GPU memory occupied by a single process. Adapted only for the case where the backend is set to "pipeline". ',
default=None, default=None,
) )
@click.option( @click.option(
'-r', '--source',
'--repo',
'model_source', 'model_source',
type=click.Choice(['huggingface', 'modelscope', 'local']), type=click.Choice(['huggingface', 'modelscope', 'local']),
help=""" help="""
......
...@@ -58,43 +58,90 @@ def configure_model(model_dir, model_type): ...@@ -58,43 +58,90 @@ def configure_model(model_dir, model_type):
@click.command() @click.command()
def download_models(): @click.option(
'-s',
'--source',
'model_source',
type=click.Choice(['huggingface', 'modelscope']),
help="""
The source of the model repository.
""",
default=None,
)
@click.option(
'-m',
'--model_type',
'model_type',
type=click.Choice(['pipeline', 'vlm', 'all']),
help="""
The type of the model to download.
""",
default=None,
)
def download_models(model_source, model_type):
"""Download MinerU model files. """Download MinerU model files.
Supports downloading pipeline or VLM models from ModelScope or HuggingFace. Supports downloading pipeline or VLM models from ModelScope or HuggingFace.
""" """
# 交互式输入下载来源 # 如果未显式指定则交互式输入下载来源
source = click.prompt( if model_source is None:
model_source = click.prompt(
"Please select the model download source: ", "Please select the model download source: ",
type=click.Choice(['huggingface', 'modelscope']), type=click.Choice(['huggingface', 'modelscope']),
default='huggingface' default='huggingface'
) )
os.environ['MINERU_MODEL_SOURCE'] = source if os.getenv('MINERU_MODEL_SOURCE', None) is None:
os.environ['MINERU_MODEL_SOURCE'] = model_source
# 交互式输入模型类型 # 如果未显式指定则交互式输入模型类型
if model_type is None:
model_type = click.prompt( model_type = click.prompt(
"Please select the model type to download: ", "Please select the model type to download: ",
type=click.Choice(['pipeline', 'vlm']), type=click.Choice(['pipeline', 'vlm', 'all']),
default='pipeline' default='all'
) )
click.echo(f"Downloading {model_type} model from {source}...") click.echo(f"Downloading {model_type} model from {os.getenv('MINERU_MODEL_SOURCE', None)}...")
def download_pipeline_models():
"""下载Pipeline模型"""
model_paths = [
ModelPath.doclayout_yolo,
ModelPath.yolo_v8_mfd,
ModelPath.unimernet_small,
ModelPath.pytorch_paddle,
ModelPath.layout_reader,
ModelPath.slanet_plus
]
download_finish_path = ""
for model_path in model_paths:
click.echo(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode='pipeline')
click.echo(f"Pipeline models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, model_type)
def download_vlm_models():
"""下载VLM模型"""
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode='vlm')
click.echo(f"VLM models downloaded successfully to: {download_finish_path}")
configure_model(download_finish_path, model_type)
try: try:
download_finish_path = ""
if model_type == 'pipeline': if model_type == 'pipeline':
for model_path in [ModelPath.doclayout_yolo, ModelPath.yolo_v8_mfd, ModelPath.unimernet_small, ModelPath.pytorch_paddle, ModelPath.layout_reader, ModelPath.slanet_plus]: download_pipeline_models()
click.echo(f"Downloading model: {model_path}")
download_finish_path = auto_download_and_get_model_root_path(model_path, repo_mode=model_type)
elif model_type == 'vlm': elif model_type == 'vlm':
download_finish_path = auto_download_and_get_model_root_path("/", repo_mode=model_type) download_vlm_models()
click.echo(f"Models downloaded successfully to: {download_finish_path}") elif model_type == 'all':
configure_model(download_finish_path, model_type) download_pipeline_models()
download_vlm_models()
else:
click.echo(f"Unsupported model type: {model_type}", err=True)
sys.exit(1)
except Exception as e: except Exception as e:
click.echo(f"Download failed: {str(e)}", err=True) click.echo(f"Download failed: {str(e)}", err=True)
sys.exit(1) sys.exit(1)
if __name__ == '__main__': if __name__ == '__main__':
download_models() download_models()
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment