model_downloader.py 3.33 KB
Newer Older
wangkaixiong's avatar
wangkaixiong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import time
import argparse
import logging
from datetime import datetime
from modelscope import snapshot_download
from requests.exceptions import ConnectionError
from urllib3.exceptions import MaxRetryError

def setup_logging(log_file):
    """Configure root logging to write to ``log_file`` and to the console.

    The parent directory of ``log_file`` is created first if it is missing.
    Without this, ``logging.FileHandler`` raises ``FileNotFoundError``, for
    example for the default ``./log_downloads`` directory, which nothing else
    creates.

    Args:
        log_file: Path of the log file. It is opened in append mode and
            written as UTF-8, so non-ASCII messages work regardless of the
            system locale.

    Returns:
        The logger for this module (``logging.getLogger(__name__)``).

    Note:
        ``logging.basicConfig`` leaves an already-configured root logger
        unchanged, so the handlers are only installed on the first call.
    """
    import os

    log_dir = os.path.dirname(log_file)
    if log_dir:  # a bare file name lives in the current directory
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger(__name__)

def download_model_with_retry(model_name, cache_dir, max_retries=100, retry_delay=10,
                              max_delay=600):
    """Download a ModelScope model snapshot, retrying on network errors.

    Network failures (``ConnectionError`` / ``MaxRetryError``) are retried
    with exponential backoff. The wait doubles after every failure but never
    exceeds ``max_delay`` seconds. Without the cap, the default of 100
    attempts led to multi-year sleeps and eventually an ``OverflowError``
    from ``time.sleep``. Any other exception is logged and re-raised
    immediately.

    Args:
        model_name: Model ID passed to ``snapshot_download``.
        cache_dir: Cache directory passed to ``snapshot_download``.
        max_retries: Total number of download attempts; must be >= 1.
        retry_delay: Wait in seconds after the first failure; must be >= 0.
        max_delay: Upper bound in seconds for any single wait; must be >= 0.

    Returns:
        The value returned by ``snapshot_download`` (presumably the local
        snapshot directory).

    Raises:
        ValueError: If ``max_retries`` < 1, or ``retry_delay`` or
            ``max_delay`` is negative.
        ConnectionError, MaxRetryError: Re-raised when the last attempt fails.
        Exception: Any other error from ``snapshot_download``, re-raised.
    """
    # Previously max_retries <= 0 silently returned None without trying.
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")
    if retry_delay < 0 or max_delay < 0:
        raise ValueError("retry_delay and max_delay must be non-negative")

    logger = logging.getLogger(__name__)
    delay = min(retry_delay, max_delay)
    for attempt in range(1, max_retries + 1):
        try:
            logger.info(f"��������� (�� {attempt} ��)...")
            model_dir = snapshot_download(model_name, cache_dir=cache_dir)
            logger.info("ģ�����سɹ�!")
            return model_dir
        except (ConnectionError, MaxRetryError) as e:
            logger.error(f"����ʧ��: {str(e)}")
            if attempt >= max_retries:
                logger.error(f"�Ѵﵽ������Դ��� {max_retries}������ʧ�ܡ�")
                raise
            logger.info(f"{delay}�������...")
            time.sleep(delay)
            # Exponential backoff, capped so waits stay bounded.
            delay = min(delay * 2, max_delay)
        except Exception as e:
            logger.error(f"����δ֪����: {str(e)}")
            raise

def main():
    """Command-line entry point: parse arguments, set up logging, download the model.

    The log file is ``<log-dir>/<model-id with '/' replaced by '_'>_<YYYYMMDD>.log``.
    If the download ultimately fails, the error is logged and the process
    exits with status 1 so that shell scripts and schedulers can detect it.
    Previously the script exited with status 0 even on failure.
    """
    import os
    import sys

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='����ģ�ͽű�')
    parser.add_argument('--model-id', type=str, required=True,
                        help='Ҫ���ص�ģ��ID������ deepseek-ai/DeepSeek-R1-0528')
    parser.add_argument('--cache-dir', type=str, required=True,
                        help='ģ�ͻ���Ŀ¼·��')
    parser.add_argument('--log-dir', type=str, default='./log_downloads',
                        help='��־�ļ�Ŀ¼·��')
    parser.add_argument('--max-retries', type=int, default=100,
                        help='������Դ�����Ĭ��Ϊ100')
    parser.add_argument('--retry-delay', type=int, default=10,
                        help='��ʼ�����ӳ�ʱ��(��)��Ĭ��Ϊ10��')

    args = parser.parse_args()

    # Build the log file name from the model ID ('/' is not allowed in a
    # file name) and today's date.
    model_id_simple = args.model_id.replace('/', '_')
    current_date = datetime.now().strftime('%Y%m%d')
    log_file = os.path.join(args.log_dir, f"{model_id_simple}_{current_date}.log")

    # Set up logging
    logger = setup_logging(log_file)

    logger.info(f"��ʼ����ģ��: {args.model_id}")
    logger.info(f"����Ŀ¼: {args.cache_dir}")
    logger.info(f"��־�ļ�: {log_file}")

    try:
        model_dir = download_model_with_retry(
            args.model_id,
            cache_dir=args.cache_dir,
            max_retries=args.max_retries,
            retry_delay=args.retry_delay,
        )
    except Exception as e:
        logger.error(f"��������ʧ��: {str(e)}")
        # Non-zero exit status signals the failure to the calling process.
        sys.exit(1)

    logger.info(f"ģ�������ص�: {model_dir}")

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()