Commit 57fe5c63 authored by wanglch
Browse files

Update multi_dcu_train.py

parent 3d7eb65b
......@@ -81,7 +81,7 @@ def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
model.train()
for batch, batch_data in enumerate(dataloader, start=1):
- batch_data = batch_data.to(device)
+ batch_data = {k: v.to(device) for k, v in batch_data.items()}
outputs = model(**batch_data)
loss = outputs.loss
loss = loss.mean()
......@@ -101,7 +101,7 @@ def test_loop(dataloader, model):
model.eval()
for batch_data in tqdm(dataloader):
- batch_data = batch_data.to(device)
+ batch_data = {k: v.to(device) for k, v in batch_data.items()}
with torch.no_grad():
# 如果你使用了 DataParallel,你可以通过访问 model.module 来获取原始模型
generated_tokens = model.module.generate(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment