Commit 57fe5c63 authored by wanglch

Update multi_dcu_train.py

parent 3d7eb65b
@@ -81,7 +81,7 @@ def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
     model.train()
     for batch, batch_data in enumerate(dataloader, start=1):
-        batch_data = batch_data.to(device)
+        batch_data = {k: v.to(device) for k, v in batch_data.items()}
         outputs = model(**batch_data)
         loss = outputs.loss
         loss = loss.mean()
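
Why this line changes: if the dataloader's collate function returns a plain Python dict of tensors rather than an object with its own .to() method, the batch cannot be moved to the device in one call, so each tensor is transferred individually. A minimal sketch of that pattern, assuming a dict batch; the batch_to_device helper and the tensor values are illustrative, not taken from multi_dcu_train.py:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batch_to_device(batch_data, device):
    # A plain dict has no .to() method, so move each tensor separately,
    # matching the dict comprehension introduced in the diff above.
    return {k: v.to(device) for k, v in batch_data.items()}

# Illustrative batch such as a tokenizer might produce.
batch_data = {
    "input_ids": torch.tensor([[101, 2054, 102]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
}
batch_data = batch_to_device(batch_data, device)
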
@@ -101,7 +101,7 @@ def test_loop(dataloader, model):
     model.eval()
     for batch_data in tqdm(dataloader):
-        batch_data = batch_data.to(device)
+        batch_data = {k: v.to(device) for k, v in batch_data.items()}
         with torch.no_grad():
             # If you are using DataParallel, you can access the original model via model.module
             generated_tokens = model.module.generate(
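
Why generate is reached through model.module: torch.nn.DataParallel only proxies forward(), so convenience methods such as generate() are not available on the wrapper and must be called on the underlying module. A minimal sketch of that pattern, assuming a Hugging Face seq2seq model wrapped in DataParallel; the checkpoint name and input text are illustrative, not the ones used by multi_dcu_train.py:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative checkpoint; the model actually loaded by the script is not shown in this diff.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(device)
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = torch.nn.DataParallel(model)

batch_data = tokenizer(["translate English to German: Hello"], return_tensors="pt")
batch_data = {k: v.to(device) for k, v in batch_data.items()}

with torch.no_grad():
    # Calling model.generate(...) would raise AttributeError on the DataParallel wrapper;
    # generate() lives on the wrapped module.
    generated_tokens = model.module.generate(
        batch_data["input_ids"],
        attention_mask=batch_data["attention_mask"],
        max_new_tokens=20,
    )
print(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
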