Commit 938ed894 authored by wanglch's avatar wanglch
Browse files

Update multi_dcu_train.py

parent 2cf5e2ac
......@@ -149,17 +149,17 @@ if __name__=='__main__':
beam_size = 4
no_repeat_ngram_size = 2
folder_path = "/saves/train_dtk_weights"
folder_path = "./saves/train_dtk_weights"
# 检查文件夹是否存在
if not os.path.exists(folder_path):
# 如果不存在,则创建文件夹
os.makedirs(folder_path)
train_data = LCSTS('/umt5/data/lcsts_tsv/data1.tsv')
valid_data = LCSTS('/umt5/data/lcsts_tsv/data2.tsv')
train_data = LCSTS('./data/lcsts_tsv/data1.tsv')
valid_data = LCSTS('./data/lcsts_tsv/data2.tsv')
model_checkpoint = "/umt5/umt5_base"
model_checkpoint = "./umt5_base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
......@@ -192,7 +192,7 @@ if __name__=='__main__':
if rouge_avg > best_avg_rouge:
best_avg_rouge = rouge_avg
print('saving new weights...\n')
weight_path = f'/utm5/saves/train_dtk_weights/epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin'
weight_path = f'./saves/train_dtk_weights/epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin'
torch.save(model.state_dict(), weight_path)
# 加载训练后的权重
state_dict = torch.load(weight_path)
......@@ -200,7 +200,7 @@ if __name__=='__main__':
# 获取当前的日期和时间
now = datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S")
new_model_path = f'saves/umt5_{timestamp}'
new_model_path = f'./saves/umt5_{timestamp}'
model.module.save_pretrained(new_model_path)
tokenizer.save_pretrained(new_model_path)
print("Done!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment