Commit 109633be authored by wanglch's avatar wanglch
Browse files

Update umt5_summary.py

parent 7c115c82
......@@ -16,12 +16,6 @@ trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.43
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
# 检查是否有多个 GPU 可用
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
# 如果有多个 GPUs,使用 nn.DataParallel 包装模型
model = nn.DataParallel(model).to(device)
model.load_state_dict(torch.load(trained_model_weights))
model = model.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment