Commit 42eac2d4 authored by wangkx1's avatar wangkx1
Browse files

Update model_downloader.py

parent f0129d66
...@@ -7,7 +7,7 @@ from requests.exceptions import ConnectionError ...@@ -7,7 +7,7 @@ from requests.exceptions import ConnectionError
from urllib3.exceptions import MaxRetryError from urllib3.exceptions import MaxRetryError
def setup_logging(log_file): def setup_logging(log_file):
"""������־��¼""" """配置日志记录"""
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format='%(asctime)s - %(levelname)s - %(message)s',
...@@ -23,52 +23,52 @@ def download_model_with_retry(model_name, cache_dir, max_retries=100, retry_dela ...@@ -23,52 +23,52 @@ def download_model_with_retry(model_name, cache_dir, max_retries=100, retry_dela
retries = 0 retries = 0
while retries < max_retries: while retries < max_retries:
try: try:
logger.info(f"��������ģ�� (�� {retries + 1} ��)...") logger.info(f"尝试下载模型 (第 {retries + 1} )...")
model_dir = snapshot_download(model_name, cache_dir=cache_dir) model_dir = snapshot_download(model_name, cache_dir=cache_dir)
logger.info("ģ�����سɹ�!") logger.info("模型下载成功!")
return model_dir return model_dir
except (ConnectionError, MaxRetryError) as e: except (ConnectionError, MaxRetryError) as e:
retries += 1 retries += 1
logger.error(f"����ʧ��: {str(e)}") logger.error(f"下载失败: {str(e)}")
if retries < max_retries: if retries < max_retries:
logger.info(f"{retry_delay}�������...") logger.info(f"{retry_delay}秒后重试...")
time.sleep(retry_delay) time.sleep(retry_delay)
retry_delay *= 2 # ָ���˱� retry_delay *= 2 # 指数退避
else: else:
logger.error(f"�Ѵﵽ������Դ��� {max_retries}������ʧ�ܡ�") logger.error(f"已达到最大重试次数 {max_retries},下载失败。")
raise raise
except Exception as e: except Exception as e:
logger.error(f"����δ֪����: {str(e)}") logger.error(f"发生未知错误: {str(e)}")
raise raise
def main(): def main():
# ���������в��� # 设置命令行参数
parser = argparse.ArgumentParser(description='����ģ�ͽű�') parser = argparse.ArgumentParser(description='下载模型脚本')
parser.add_argument('--model-id', type=str, required=True, parser.add_argument('--model-id', type=str, required=True,
help='Ҫ���ص�ģ��ID������ deepseek-ai/DeepSeek-R1-0528') help='要下载的模型ID,例如 deepseek-ai/DeepSeek-R1-0528')
parser.add_argument('--cache-dir', type=str, required=True, parser.add_argument('--cache-dir', type=str, required=True,
help='ģ�ͻ���Ŀ¼·��') help='模型缓存目录路径')
parser.add_argument('--log-dir', type=str, default='./log_downloads', parser.add_argument('--log-dir', type=str, default='./log_downloads',
help='��־�ļ�Ŀ¼·��') help='日志文件目录路径')
parser.add_argument('--max-retries', type=int, default=100, parser.add_argument('--max-retries', type=int, default=100,
help='������Դ�����Ĭ��Ϊ100') help='最大重试次数,默认为100')
parser.add_argument('--retry-delay', type=int, default=10, parser.add_argument('--retry-delay', type=int, default=10,
help='��ʼ�����ӳ�ʱ��(��)��Ĭ��Ϊ10��') help='初始重试延迟时间(秒),默认为10秒')
args = parser.parse_args() args = parser.parse_args()
# ������־�ļ��� # 创建日志文件名
model_id_simple = args.model_id.replace('/', '_') model_id_simple = args.model_id.replace('/', '_')
current_date = datetime.now().strftime('%Y%m%d') current_date = datetime.now().strftime('%Y%m%d')
log_file = f"{args.log_dir}/{model_id_simple}_{current_date}.log" log_file = f"{args.log_dir}/{model_id_simple}_{current_date}.log"
# ������־ # 配置日志
logger = setup_logging(log_file) logger = setup_logging(log_file)
try: try:
logger.info(f"��ʼ����ģ��: {args.model_id}") logger.info(f"开始下载模型: {args.model_id}")
logger.info(f"����Ŀ¼: {args.cache_dir}") logger.info(f"缓存目录: {args.cache_dir}")
logger.info(f"��־�ļ�: {log_file}") logger.info(f"日志文件: {log_file}")
model_dir = download_model_with_retry( model_dir = download_model_with_retry(
args.model_id, args.model_id,
...@@ -77,9 +77,9 @@ def main(): ...@@ -77,9 +77,9 @@ def main():
retry_delay=args.retry_delay retry_delay=args.retry_delay
) )
except Exception as e: except Exception as e:
logger.error(f"��������ʧ��: {str(e)}") logger.error(f"最终下载失败: {str(e)}")
else: else:
logger.info(f"ģ�������ص�: {model_dir}") logger.info(f"模型已下载到: {model_dir}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment