Commit 42eac2d4 authored by wangkx1's avatar wangkx1
Browse files

Update model_downloader.py

parent f0129d66
......@@ -7,7 +7,7 @@ from requests.exceptions import ConnectionError
from urllib3.exceptions import MaxRetryError
def setup_logging(log_file):
"""������־��¼"""
"""配置日志记录"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
......@@ -23,52 +23,52 @@ def download_model_with_retry(model_name, cache_dir, max_retries=100, retry_dela
retries = 0
while retries < max_retries:
try:
logger.info(f"��������� (�� {retries + 1} ��)...")
logger.info(f"尝试下载模型 (第 {retries + 1} )...")
model_dir = snapshot_download(model_name, cache_dir=cache_dir)
logger.info("ģ�����سɹ�!")
logger.info("模型下载成功!")
return model_dir
except (ConnectionError, MaxRetryError) as e:
retries += 1
logger.error(f"����ʧ��: {str(e)}")
logger.error(f"下载失败: {str(e)}")
if retries < max_retries:
logger.info(f"{retry_delay}�������...")
logger.info(f"{retry_delay}秒后重试...")
time.sleep(retry_delay)
retry_delay *= 2 # ָ���˱�
retry_delay *= 2 # 指数退避
else:
logger.error(f"�Ѵﵽ������Դ��� {max_retries}������ʧ�ܡ�")
logger.error(f"已达到最大重试次数 {max_retries},下载失败。")
raise
except Exception as e:
logger.error(f"����δ֪����: {str(e)}")
logger.error(f"发生未知错误: {str(e)}")
raise
def main():
# �����������
parser = argparse.ArgumentParser(description='����ģ�ͽű�')
# 设置命令行参数
parser = argparse.ArgumentParser(description='下载模型脚本')
parser.add_argument('--model-id', type=str, required=True,
help='Ҫ���ص�ģ��ID������ deepseek-ai/DeepSeek-R1-0528')
help='要下载的模型ID,例如 deepseek-ai/DeepSeek-R1-0528')
parser.add_argument('--cache-dir', type=str, required=True,
help='ģ�ͻ���Ŀ¼·��')
help='模型缓存目录路径')
parser.add_argument('--log-dir', type=str, default='./log_downloads',
help='��־�ļ�Ŀ¼·��')
help='日志文件目录路径')
parser.add_argument('--max-retries', type=int, default=100,
help='������Դ�����Ĭ��Ϊ100')
help='最大重试次数,默认为100')
parser.add_argument('--retry-delay', type=int, default=10,
help='��ʼ�����ӳ�ʱ��(��)��Ĭ��Ϊ10��')
help='初始重试延迟时间(秒),默认为10秒')
args = parser.parse_args()
# ������־�ļ���
# 创建日志文件名
model_id_simple = args.model_id.replace('/', '_')
current_date = datetime.now().strftime('%Y%m%d')
log_file = f"{args.log_dir}/{model_id_simple}_{current_date}.log"
# ������־
# 配置日志
logger = setup_logging(log_file)
try:
logger.info(f"��ʼ����ģ��: {args.model_id}")
logger.info(f"����Ŀ¼: {args.cache_dir}")
logger.info(f"��־�ļ�: {log_file}")
logger.info(f"开始下载模型: {args.model_id}")
logger.info(f"缓存目录: {args.cache_dir}")
logger.info(f"日志文件: {log_file}")
model_dir = download_model_with_retry(
args.model_id,
......@@ -77,9 +77,9 @@ def main():
retry_delay=args.retry_delay
)
except Exception as e:
logger.error(f"��������ʧ��: {str(e)}")
logger.error(f"最终下载失败: {str(e)}")
else:
logger.info(f"ģ�������ص�: {model_dir}")
logger.info(f"模型已下载到: {model_dir}")
if __name__ == "__main__":
main()
\ No newline at end of file
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment