Commit d6620b5e authored by wanglch's avatar wanglch
Browse files

Update umt5_summary.py

parent 3d0ce953
......@@ -11,6 +11,7 @@ print(f'Using {device} device')
# 根据用户路径修改对应路径即可
model_checkpoint = "/umt5/umt5_base"
# 训练权重需要预训练后载入
trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment