Commit 3ed28c37 authored by wanglch's avatar wanglch
Browse files

Update train_multi_dcu.py

parent a9deb018
......@@ -104,7 +104,7 @@ def test_loop(dataloader, model):
batch_data = {k: v.to(device) for k, v in batch_data.items()}
with torch.no_grad():
# 如果你使用了 DataParallel,你可以通过访问 model.module 来获取原始模型
generated_tokens = model.generate(
generated_tokens = model.module.generate(
batch_data["input_ids"],
attention_mask=batch_data["attention_mask"],
max_length=max_target_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment