Commit dcb1010d authored by wanglch's avatar wanglch
Browse files

Update single_dcu_train.py

parent 938ed894
...@@ -156,10 +156,10 @@ if __name__=='__main__': ...@@ -156,10 +156,10 @@ if __name__=='__main__':
# 如果不存在,则创建文件夹 # 如果不存在,则创建文件夹
os.makedirs(folder_path) os.makedirs(folder_path)
train_data = LCSTS('/umt5/data/lcsts_tsv/data1.tsv') train_data = LCSTS('./data/lcsts_tsv/data1.tsv')
valid_data = LCSTS('/umt5/data/lcsts_tsv/data2.tsv') valid_data = LCSTS('./data/lcsts_tsv/data2.tsv')
model_checkpoint = "/umt5/umt5_base" model_checkpoint = "./umt5_base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint) model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
...@@ -194,7 +194,7 @@ if __name__=='__main__': ...@@ -194,7 +194,7 @@ if __name__=='__main__':
if rouge_avg > best_avg_rouge: if rouge_avg > best_avg_rouge:
best_avg_rouge = rouge_avg best_avg_rouge = rouge_avg
print('saving new weights...\n') print('saving new weights...\n')
weight_path = f'/utm5/saves/train_dtk_weights/epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin' weight_path = f'./saves/train_dtk_weights/epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin'
torch.save(model.state_dict(), weight_path) torch.save(model.state_dict(), weight_path)
# 加载训练后的权重 # 加载训练后的权重
state_dict = torch.load(weight_path) state_dict = torch.load(weight_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment