Commit 4faeb38b authored by Ubuntu's avatar Ubuntu
Browse files

Fix loss loss logging for multi-gpu compatibility

parent 25f73add
...@@ -529,10 +529,10 @@ def main(): ...@@ -529,10 +529,10 @@ def main():
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids) loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
total_tr_loss += loss.item() total_tr_loss += loss.sum().item() # sum() is to account for multi-gpu support.
nb_tr_examples += input_ids.size(0) nb_tr_examples += input_ids.size(0)
model.zero_grad() model.zero_grad()
loss.backward() loss.sum().backward() # sum() is to account for multi-gpu support.
optimizer.step() optimizer.step()
global_step += 1 global_step += 1
...@@ -573,7 +573,7 @@ def main(): ...@@ -573,7 +573,7 @@ def main():
label_ids = label_ids.to('cpu').numpy() label_ids = label_ids.to('cpu').numpy()
tmp_eval_accuracy = accuracy(logits, label_ids) tmp_eval_accuracy = accuracy(logits, label_ids)
eval_loss += tmp_eval_loss.item() eval_loss += tmp_eval_loss.sum().item()
eval_accuracy += tmp_eval_accuracy eval_accuracy += tmp_eval_accuracy
nb_eval_examples += input_ids.size(0) nb_eval_examples += input_ids.size(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment