Commit 3d0ce953 authored by wanglch's avatar wanglch
Browse files

Update umt5_summary.py

parent b84362fa
......@@ -9,6 +9,7 @@ os.environ["HIP_VISIBLE_DEVICES"] = "4,5"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')
# 根据用户路径修改对应路径即可
model_checkpoint = "/umt5/umt5_base"
trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment