"tests/utils/test_offline.py" did not exist on "88a951e3cc00f56b94d9b93dbc35a3812cd88747"
Unverified Commit 91df4551 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Trainer] Make sure shown loss in distributed training is correctly averaged...

[Trainer] Make sure shown loss in distributed training is correctly averaged over all workers (#13681)

* push

* improve tr loss gather
parent 044eff5b
......@@ -1462,7 +1462,10 @@ class Trainer:
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log:
logs: Dict[str, float] = {}
tr_loss_scalar = tr_loss.item()
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
# reset tr_loss to zero
tr_loss -= tr_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment