"""Download a model snapshot via ModelScope, retrying on connection errors.

Command-line script: see ``main`` for the accepted arguments. Progress is
logged both to the console and to a per-model, per-day log file.
"""

import argparse
import logging
import os
import sys
import time
from datetime import datetime

from modelscope import snapshot_download
# NOTE: this shadows the builtin ConnectionError with the requests one.
from requests.exceptions import ConnectionError
from urllib3.exceptions import MaxRetryError


def setup_logging(log_file):
    """Configure root logging to write to ``log_file`` and to the console.

    Args:
        log_file: Path of the log file. Its parent directory is created
            if it does not exist yet.

    Returns:
        The logger named after this module.
    """
    # FileHandler does not create missing directories; without this the
    # default './log_downloads' location fails on a fresh checkout.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Log messages contain non-ASCII text, so pin the encoding
            # instead of relying on the platform default.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger(__name__)


def download_model_with_retry(model_name, cache_dir, max_retries=100, retry_delay=10):
    """Call ``snapshot_download`` until it succeeds or attempts run out.

    Only ``ConnectionError`` (requests) and ``MaxRetryError`` (urllib3) are
    retried; any other exception is logged and re-raised immediately.

    Args:
        model_name: Model identifier passed to ``snapshot_download``.
        cache_dir: Passed through as ``cache_dir`` to ``snapshot_download``.
        max_retries: Maximum number of download attempts; must be >= 1.
        retry_delay: Seconds to sleep between attempts (constant delay).

    Returns:
        Whatever ``snapshot_download`` returns (logged by the caller as the
        download location).

    Raises:
        ValueError: If ``max_retries`` is less than 1.
        ConnectionError, MaxRetryError: If the last attempt still fails.
        Exception: Any non-retryable error raised by ``snapshot_download``.
    """
    if max_retries < 1:
        # Previously the loop was skipped entirely and None was returned,
        # which the caller then reported as a successful download.
        raise ValueError(f"max_retries 必须至少为 1,当前值: {max_retries}")

    logger = logging.getLogger(__name__)

    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"尝试下载模型 (第 {attempt} 次)...")
            model_dir = snapshot_download(model_name, cache_dir=cache_dir)
        except (ConnectionError, MaxRetryError) as e:
            logger.error(f"下载失败: {str(e)}")
            if attempt >= max_retries:
                logger.error(f"已达到最大重试次数 {max_retries},下载失败。")
                raise
            logger.info(f"{retry_delay}秒后重试...")
            time.sleep(retry_delay)
        except Exception as e:
            logger.error(f"发生未知错误: {str(e)}")
            raise
        else:
            logger.info("模型下载成功!")
            return model_dir


def main():
    """Parse command-line arguments and run the download.

    Returns:
        0 when the model was downloaded, 1 when the download failed.
    """
    # Command-line arguments.
    parser = argparse.ArgumentParser(description='下载模型脚本')
    parser.add_argument('--model-id', type=str, required=True,
                        help='要下载的模型ID,例如 deepseek-ai/DeepSeek-R1-0528')
    parser.add_argument('--cache-dir', type=str, required=True,
                        help='模型缓存目录路径')
    parser.add_argument('--log-dir', type=str, default='./log_downloads',
                        help='日志文件目录路径')
    parser.add_argument('--max-retries', type=int, default=100,
                        help='最大重试次数,默认为100')
    parser.add_argument('--retry-delay', type=int, default=10,
                        help='初始重试延迟时间(秒),默认为10秒')
    args = parser.parse_args()

    # Build the log file name: '/' in the model id would otherwise be
    # interpreted as a directory separator.
    model_id_simple = args.model_id.replace('/', '_')
    current_date = datetime.now().strftime('%Y%m%d')
    log_file = os.path.join(args.log_dir, f"{model_id_simple}_{current_date}.log")

    # Configure logging.
    logger = setup_logging(log_file)

    try:
        logger.info(f"开始下载模型: {args.model_id}")
        logger.info(f"缓存目录: {args.cache_dir}")
        logger.info(f"日志文件: {log_file}")

        model_dir = download_model_with_retry(
            args.model_id,
            cache_dir=args.cache_dir,
            max_retries=args.max_retries,
            retry_delay=args.retry_delay,
        )
    except Exception as e:
        # Top-level boundary: report the failure and signal it through the
        # exit status instead of exiting with 0.
        logger.error(f"最终下载失败: {str(e)}")
        return 1

    logger.info(f"模型已下载到: {model_dir}")
    return 0


if __name__ == "__main__":
    sys.exit(main())