"vscode:/vscode.git/clone" did not exist on "d248e7686f8487438261f34c49217b88c2a6dfbe"
Commit a5866d29 authored by wanglch
Browse files

Update train_single_dcu.py

parent dd221315
...@@ -149,10 +149,10 @@ if __name__=='__main__': ...@@ -149,10 +149,10 @@ if __name__=='__main__':
beam_size = 4 beam_size = 4
no_repeat_ngram_size = 2 no_repeat_ngram_size = 2
train_data = LCSTS('/home/wanglch/projects/Umt5/data/lcsts_tsv/data1.tsv') train_data = LCSTS('../Umt5/data/lcsts_tsv/data1.tsv')
valid_data = LCSTS('/home/wanglch/projects/Umt5/data/lcsts_tsv/data2.tsv') valid_data = LCSTS('../Umt5/data/lcsts_tsv/data2.tsv')
model_checkpoint = "/home/wanglch/projects/Umt5/umt5_base" model_checkpoint = "../Umt5/umt5_base"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, trust_remote_code=True) model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, trust_remote_code=True)
model = model.to(device) model = model.to(device)
...@@ -180,7 +180,7 @@ if __name__=='__main__': ...@@ -180,7 +180,7 @@ if __name__=='__main__':
if rouge_avg > best_avg_rouge: if rouge_avg > best_avg_rouge:
best_avg_rouge = rouge_avg best_avg_rouge = rouge_avg
print('saving new weights...\n') print('saving new weights...\n')
weight_path = f'/home/wanglch/projects/saves/utm5/train_dtk_weights/epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin' weight_path = f'../saves/utm5/train_dtk_weights/epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin'
torch.save(model.state_dict(), weight_path) torch.save(model.state_dict(), weight_path)
# 加载训练后的权重 # 加载训练后的权重
state_dict = torch.load(weight_path) state_dict = torch.load(weight_path)
...@@ -188,7 +188,7 @@ if __name__=='__main__': ...@@ -188,7 +188,7 @@ if __name__=='__main__':
# 获取当前的日期和时间 # 获取当前的日期和时间
now = datetime.now() now = datetime.now()
timestamp = now.strftime("%Y%m%d_%H%M%S") timestamp = now.strftime("%Y%m%d_%H%M%S")
new_model_path = f'/home/wanglch/projects/saves/utm5/train_dtk_weights/umt5_{timestamp}' new_model_path = f'../saves/utm5/train_dtk_weights/umt5_{timestamp}'
model.save_pretrained(new_model_path) model.save_pretrained(new_model_path)
tokenizer.save_pretrained(new_model_path) tokenizer.save_pretrained(new_model_path)
print("Done!") print("Done!")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment