Commit 3ed28c37 authored by wanglch's avatar wanglch
Browse files

Update train_multi_dcu.py

parent a9deb018
...@@ -104,7 +104,7 @@ def test_loop(dataloader, model): ...@@ -104,7 +104,7 @@ def test_loop(dataloader, model):
batch_data = {k: v.to(device) for k, v in batch_data.items()} batch_data = {k: v.to(device) for k, v in batch_data.items()}
with torch.no_grad(): with torch.no_grad():
# 如果你使用了 DataParallel,你可以通过访问 model.module 来获取原始模型 # 如果你使用了 DataParallel,你可以通过访问 model.module 来获取原始模型
generated_tokens = model.generate( generated_tokens = model.module.generate(
batch_data["input_ids"], batch_data["input_ids"],
attention_mask=batch_data["attention_mask"], attention_mask=batch_data["attention_mask"],
max_length=max_target_length, max_length=max_target_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment