import argparse
import logging
import os
import time
from datetime import datetime

from modelscope import snapshot_download
from requests.exceptions import ConnectionError
from urllib3.exceptions import MaxRetryError

def setup_logging(log_file):
    """Configure root logging to write both to *log_file* and to stderr.

    Args:
        log_file: Path of the log file. Missing parent directories are
            created, because ``logging.FileHandler`` does not create them.

    Returns:
        The logger for this module.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so only the first effective call in a process configures
        logging.
    """
    log_dir = os.path.dirname(log_file)
    if log_dir:
        # Without this, FileHandler raises FileNotFoundError on a fresh run
        # where the log directory has not been created yet.
        os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # This script logs Chinese text; don't depend on the platform's
            # default locale encoding for the log file.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger(__name__)

def download_model_with_retry(model_name, cache_dir, max_retries=100, retry_delay=10,
                              max_delay=600):
    """Download a model snapshot via ``snapshot_download``, retrying on network errors.

    Network failures trigger a retry after a delay that doubles after each
    failed attempt (exponential backoff). The delay is capped so that long
    retry budgets do not produce absurd sleeps.

    Args:
        model_name: Model ID passed to ``snapshot_download``.
        cache_dir: Cache directory passed to ``snapshot_download``.
        max_retries: Maximum number of download attempts; must be >= 1.
        retry_delay: Initial delay in seconds before the first retry; must be >= 0.
        max_delay: Upper bound in seconds for the backoff delay (never lower
            than ``retry_delay``). ``None`` disables the cap.

    Returns:
        Whatever ``snapshot_download`` returns (presumably the local model
        directory).

    Raises:
        ValueError: If ``max_retries`` < 1 or ``retry_delay`` < 0.
        ConnectionError, MaxRetryError: The last network error, re-raised
            once all attempts have been used.
        Exception: Any other error is logged and re-raised immediately,
            without retrying.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    if retry_delay < 0:
        raise ValueError(f"retry_delay must be >= 0, got {retry_delay}")

    logger = logging.getLogger(__name__)
    # The cap never drops below the initial delay, so a large explicit
    # retry_delay is still honoured.
    cap = None if max_delay is None else max(max_delay, retry_delay)
    delay = retry_delay

    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"尝试下载模型 (第 {attempt} 次)...")
            model_dir = snapshot_download(model_name, cache_dir=cache_dir)
        # NOTE: ConnectionError here is requests.exceptions.ConnectionError
        # (imported at module level), which shadows the builtin of that name.
        except (ConnectionError, MaxRetryError) as e:
            logger.error(f"下载失败: {str(e)}")
            if attempt >= max_retries:
                logger.error(f"已达到最大重试次数 {max_retries}，下载失败。")
                raise
            logger.info(f"{delay}秒后重试...")
            time.sleep(delay)
            delay *= 2  # exponential backoff
            if cap is not None:
                delay = min(delay, cap)
        except Exception as e:
            # Non-network errors are not retried; keep the traceback in the log.
            logger.exception(f"发生未知错误: {str(e)}")
            raise
        else:
            logger.info("模型下载成功!")
            return model_dir

def main(argv=None):
    """Command-line entry point: parse arguments, set up logging, download the model.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Exits with status 2 (via ``argparse``) on invalid arguments, and with
    status 1 if the download ultimately fails.
    """
    # Command-line arguments
    parser = argparse.ArgumentParser(description='下载模型脚本')
    parser.add_argument('--model-id', type=str, required=True,
                        help='要下载的模型ID，例如 deepseek-ai/DeepSeek-R1-0528')
    parser.add_argument('--cache-dir', type=str, required=True,
                        help='模型缓存目录路径')
    parser.add_argument('--log-dir', type=str, default='./log_downloads',
                        help='日志文件目录路径')
    parser.add_argument('--max-retries', type=int, default=100,
                        help='最大重试次数，默认为100')
    parser.add_argument('--retry-delay', type=int, default=10,
                        help='初始重试延迟时间(秒)，默认为10秒')

    args = parser.parse_args(argv)
    # Reject values that would make the retry loop do nothing or make
    # time.sleep fail, before any log file is created.
    if args.max_retries < 1:
        parser.error('--max-retries 必须大于等于 1')
    if args.retry_delay < 0:
        parser.error('--retry-delay 不能为负数')

    # Log file name: <log_dir>/<model id with '/' replaced by '_'>_<YYYYMMDD>.log
    model_id_simple = args.model_id.replace('/', '_')
    current_date = datetime.now().strftime('%Y%m%d')
    # FileHandler does not create missing directories.
    os.makedirs(args.log_dir, exist_ok=True)
    log_file = os.path.join(args.log_dir, f"{model_id_simple}_{current_date}.log")

    logger = setup_logging(log_file)

    logger.info(f"开始下载模型: {args.model_id}")
    logger.info(f"缓存目录: {args.cache_dir}")
    logger.info(f"日志文件: {log_file}")

    try:
        model_dir = download_model_with_retry(
            args.model_id,
            cache_dir=args.cache_dir,
            max_retries=args.max_retries,
            retry_delay=args.retry_delay,
        )
    except Exception as e:
        logger.error(f"最终下载失败: {str(e)}")
        # Non-zero exit status so shell scripts / schedulers can detect failure.
        raise SystemExit(1) from e
    logger.info(f"模型已下载到: {model_dir}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
